From 87e459ac4ec2ac9fc25bce2a817ac820b8acbbfe Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 30 Oct 2018 17:51:31 +0000 Subject: [PATCH 01/67] [FEATURE][ML] ML data frame analytics --- .../xpack/core/XPackClientPlugin.java | 6 +- .../core/ml/action/RunAnalyticsAction.java | 109 ++++++++++ .../xpack/ml/MachineLearning.java | 21 +- .../action/TransportRunAnalyticsAction.java | 97 +++++++++ .../ml/analytics/DataFrameDataExtractor.java | 193 ++++++++++++++++++ .../DataFrameDataExtractorContext.java | 33 +++ .../DataFrameDataExtractorFactory.java | 91 +++++++++ .../analytics/process/AnalyticsBuilder.java | 43 ++++ .../analytics/process/AnalyticsProcess.java | 11 + .../process/AnalyticsProcessFactory.java | 20 ++ .../process/AnalyticsProcessManager.java | 80 ++++++++ .../process/NativeAnalyticsProcess.java | 34 +++ .../NativeAnalyticsProcessFactory.java | 75 +++++++ .../extractor/fields/ExtractedFields.java | 9 + .../analytics/RestRunAnalyticsAction.java | 38 ++++ 15 files changed, 855 insertions(+), 5 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java create 
mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/analytics/RestRunAnalyticsAction.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 21bd005ac5b7c..2ddf34b35464c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -38,12 +38,12 @@ import org.elasticsearch.transport.Transport; import org.elasticsearch.xpack.core.action.XPackInfoAction; import org.elasticsearch.xpack.core.action.XPackUsageAction; +import org.elasticsearch.xpack.core.beats.BeatsFeatureSetUsage; import org.elasticsearch.xpack.core.ccr.AutoFollowMetadata; import org.elasticsearch.xpack.core.deprecation.DeprecationInfoAction; import org.elasticsearch.xpack.core.graph.GraphFeatureSetUsage; import org.elasticsearch.xpack.core.graph.action.GraphExploreAction; import org.elasticsearch.xpack.core.logstash.LogstashFeatureSetUsage; -import org.elasticsearch.xpack.core.beats.BeatsFeatureSetUsage; import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.action.CloseJobAction; @@ -85,6 +85,7 @@ import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; +import 
org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; @@ -136,8 +137,8 @@ import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.ExceptExpression; import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.FieldExpression; import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.RoleMapperExpression; -import org.elasticsearch.xpack.core.security.authz.privilege.ConditionalClusterPrivileges; import org.elasticsearch.xpack.core.security.authz.privilege.ConditionalClusterPrivilege; +import org.elasticsearch.xpack.core.security.authz.privilege.ConditionalClusterPrivileges; import org.elasticsearch.xpack.core.security.transport.netty4.SecurityNetty4Transport; import org.elasticsearch.xpack.core.ssl.SSLService; import org.elasticsearch.xpack.core.ssl.action.GetCertificateInfoAction; @@ -267,6 +268,7 @@ public List> getClientActions() { PostCalendarEventsAction.INSTANCE, PersistJobAction.INSTANCE, FindFileStructureAction.INSTANCE, + RunAnalyticsAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java new file mode 100644 index 0000000000000..0e2a4eb15eb04 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java @@ -0,0 +1,109 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class RunAnalyticsAction extends Action { + + public static final RunAnalyticsAction INSTANCE = new RunAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/analytics/run"; + + private RunAnalyticsAction() { + super(NAME); + } + + @Override + public AcknowledgedResponse newResponse() { + return new AcknowledgedResponse(); + } + + public static class Request extends ActionRequest implements ToXContentObject { + + private String index; + + public Request(String index) { + this.index = index; + } + + public Request(StreamInput in) throws IOException { + readFrom(in); + } + + public Request() { + } + + public String getIndex() { + return index; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + index = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(index); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("index", index); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(index); + } + + 
@Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + RunAnalyticsAction.Request other = (RunAnalyticsAction.Request) obj; + return Objects.equals(index, other.index); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client, RunAnalyticsAction action) { + super(client, action, new Request()); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 738f5a9e1a47d..a1372382c9ccb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; @@ -149,6 +150,7 @@ import org.elasticsearch.xpack.ml.action.TransportPutFilterAction; import org.elasticsearch.xpack.ml.action.TransportPutJobAction; import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction; +import org.elasticsearch.xpack.ml.action.TransportRunAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportStopDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportUpdateCalendarJobAction; @@ -159,6 +161,9 @@ import 
org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; +import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessFactory; +import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessManager; +import org.elasticsearch.xpack.ml.analytics.process.NativeAnalyticsProcessFactory; import org.elasticsearch.xpack.ml.datafeed.DatafeedJobBuilder; import org.elasticsearch.xpack.ml.datafeed.DatafeedManager; import org.elasticsearch.xpack.ml.job.JobManager; @@ -183,6 +188,7 @@ import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction; import org.elasticsearch.xpack.ml.rest.RestFindFileStructureAction; import org.elasticsearch.xpack.ml.rest.RestMlInfoAction; +import org.elasticsearch.xpack.ml.rest.analytics.RestRunAnalyticsAction; import org.elasticsearch.xpack.ml.rest.calendar.RestDeleteCalendarAction; import org.elasticsearch.xpack.ml.rest.calendar.RestDeleteCalendarEventAction; import org.elasticsearch.xpack.ml.rest.calendar.RestDeleteCalendarJobAction; @@ -373,6 +379,7 @@ public Collection createComponents(Client client, ClusterService cluster AutodetectProcessFactory autodetectProcessFactory; NormalizerProcessFactory normalizerProcessFactory; + AnalyticsProcessFactory analyticsProcessFactory; if (MachineLearningField.AUTODETECT_PROCESS.get(settings) && MachineLearningFeatureSet.isRunningOnMlPlatform(true)) { try { NativeController nativeController = NativeControllerHolder.getNativeController(environment); @@ -387,6 +394,7 @@ public Collection createComponents(Client client, ClusterService cluster client, clusterService); normalizerProcessFactory = new NativeNormalizerProcessFactory(environment, nativeController); + analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, nativeController); } catch (IOException e) { // This also should not happen in production, as the 
MachineLearningFeatureSet should have // hit the same error first and brought down the node with a friendlier error message @@ -397,6 +405,7 @@ public Collection createComponents(Client client, ClusterService cluster new BlackHoleAutodetectProcess(job.getId()); // factor of 1.0 makes renormalization a no-op normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0); + analyticsProcessFactory = (jobId, executorService) -> null; } NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory, threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)); @@ -417,6 +426,9 @@ public Collection createComponents(Client client, ClusterService cluster // run node startup tasks autodetectProcessManager.onNodeStartup(); + AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, environment, threadPool, + analyticsProcessFactory); + return Arrays.asList( mlLifeCycleService, jobResultsProvider, @@ -426,7 +438,8 @@ public Collection createComponents(Client client, ClusterService cluster jobDataCountsPersister, datafeedManager, auditor, - new MlAssignmentNotifier(settings, auditor, clusterService) + new MlAssignmentNotifier(settings, auditor, clusterService), + analyticsProcessManager ); } @@ -509,7 +522,8 @@ public List getRestHandlers(Settings settings, RestController restC new RestPutCalendarJobAction(settings, restController), new RestGetCalendarEventsAction(settings, restController), new RestPostCalendarEventAction(settings, restController), - new RestFindFileStructureAction(settings, restController) + new RestFindFileStructureAction(settings, restController), + new RestRunAnalyticsAction(settings, restController) ); } @@ -567,7 +581,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(GetCalendarEventsAction.INSTANCE, TransportGetCalendarEventsAction.class), new ActionHandler<>(PostCalendarEventsAction.INSTANCE, 
TransportPostCalendarEventsAction.class), new ActionHandler<>(PersistJobAction.INSTANCE, TransportPersistJobAction.class), - new ActionHandler<>(FindFileStructureAction.INSTANCE, TransportFindFileStructureAction.class) + new ActionHandler<>(FindFileStructureAction.INSTANCE, TransportFindFileStructureAction.class), + new ActionHandler<>(RunAnalyticsAction.INSTANCE, TransportRunAnalyticsAction.class) ); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java new file mode 100644 index 0000000000000..6d713552832f9 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -0,0 +1,97 @@ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.env.Environment; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory; +import 
org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessManager; + +import java.util.Collections; +import java.util.Map; +import java.util.function.Supplier; + +public class TransportRunAnalyticsAction extends HandledTransportAction { + + private final TransportService transportService; + private final ThreadPool threadPool; + private final Client client; + private final ClusterService clusterService; + private final Environment environment; + private final AnalyticsProcessManager analyticsProcessManager; + + + @Inject + public TransportRunAnalyticsAction(Settings settings, ThreadPool threadPool, TransportService transportService, + ActionFilters actionFilters, Client client, ClusterService clusterService, Environment environment, + AnalyticsProcessManager analyticsProcessManager) { + super(settings, RunAnalyticsAction.NAME, transportService, actionFilters, + (Supplier) RunAnalyticsAction.Request::new); + this.transportService = transportService; + this.threadPool = threadPool; + this.client = client; + this.clusterService = clusterService; + this.environment = environment; + this.analyticsProcessManager = analyticsProcessManager; + } + + @Override + protected void doExecute(Task task, RunAnalyticsAction.Request request, ActionListener listener) { + DiscoveryNode localNode = clusterService.localNode(); + if (isMlNode(localNode)) { + runPipelineAnalytics(request, listener); + return; + } + + ClusterState clusterState = clusterService.state(); + for (DiscoveryNode node : clusterState.getNodes()) { + if (isMlNode(node)) { + transportService.sendRequest(node, actionName, request, + new ActionListenerResponseHandler<>(listener, inputStream -> { + AcknowledgedResponse response = new AcknowledgedResponse(); + response.readFrom(inputStream); + return response; + })); + return; + } + } + listener.onFailure(ExceptionsHelper.badRequestException("No ML node to run on")); + } + + private boolean isMlNode(DiscoveryNode node) { + Map nodeAttributes = node.getAttributes(); + 
String enabled = nodeAttributes.get(MachineLearning.ML_ENABLED_NODE_ATTR); + return Boolean.valueOf(enabled); + } + + private void runPipelineAnalytics(RunAnalyticsAction.Request request, ActionListener listener) { + String jobId = "ml-analytics-" + request.getIndex(); + + ActionListener dataExtractorFactoryListener = ActionListener.wrap( + dataExtractorFactory -> { + DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(); + analyticsProcessManager.processData(jobId, dataExtractor); + listener.onResponse(new AcknowledgedResponse(true)); + }, + listener::onFailure + ); + + DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), request.getIndex(), dataExtractorFactoryListener); + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java new file mode 100644 index 0000000000000..22f8befc55cd2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -0,0 +1,193 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.analytics; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.search.ClearScrollAction; +import org.elasticsearch.action.search.ClearScrollRequest; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchScrollAction; +import org.elasticsearch.action.search.SearchScrollRequestBuilder; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * An implementation that extracts data from elasticsearch using search and scroll on a client. + * It supports safe and responsive cancellation by continuing the scroll until a new timestamp + * is seen. + * Note that this class is NOT thread-safe. 
+ */ +public class DataFrameDataExtractor { + + private static final Logger LOGGER = LogManager.getLogger(DataFrameDataExtractor.class); + private static final TimeValue SCROLL_TIMEOUT = new TimeValue(30, TimeUnit.MINUTES); + private static final String EMPTY_STRING = ""; + + private final Client client; + private final DataFrameDataExtractorContext context; + private String scrollId; + private boolean isCancelled; + private boolean hasNext; + private boolean searchHasShardFailure; + + DataFrameDataExtractor(Client client, DataFrameDataExtractorContext context) { + this.client = Objects.requireNonNull(client); + this.context = Objects.requireNonNull(context); + hasNext = true; + searchHasShardFailure = false; + } + + public boolean hasNext() { + return hasNext; + } + + public boolean isCancelled() { + return isCancelled; + } + + public void cancel() { + isCancelled = true; + } + + public Optional> next() throws IOException { + if (!hasNext()) { + throw new NoSuchElementException(); + } + Optional> records = scrollId == null ? 
Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll()); + if (!records.isPresent()) { + hasNext = false; + } + return records; + } + + protected List initScroll() throws IOException { + LOGGER.debug("[{}] Initializing scroll", "analytics"); + SearchResponse searchResponse = executeSearchRequest(buildSearchRequest()); + LOGGER.debug("[{}] Search response was obtained", context.jobId); + return processSearchResponse(searchResponse); + } + + protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) { + return ClientHelper.executeWithHeaders(context.headers, ClientHelper.ML_ORIGIN, client, searchRequestBuilder::get); + } + + private SearchRequestBuilder buildSearchRequest() { + SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) + .setScroll(SCROLL_TIMEOUT) + .addSort("_doc", SortOrder.ASC) + .setIndices(context.indices) + .setSize(context.scrollSize) + .setQuery(context.query) + .setFetchSource(false); + + for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) { + searchRequestBuilder.addDocValueField(docValueField.getName(), docValueField.getDocValueFormat()); + } + + return searchRequestBuilder; + } + + private List processSearchResponse(SearchResponse searchResponse) throws IOException { + + if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) { + LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId); + markScrollAsErrored(); + return initScroll(); + } + + ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse); + scrollId = searchResponse.getScrollId(); + if (searchResponse.getHits().getHits().length == 0) { + hasNext = false; + clearScroll(scrollId); + return null; + } + + SearchHit[] hits = searchResponse.getHits().getHits(); + List records = new ArrayList<>(hits.length); + for (SearchHit hit : hits) { + if (isCancelled) { + hasNext = false; + clearScroll(scrollId); + break; + 
} + records.add(toStringArray(hit)); + } + return records; + } + + private String[] toStringArray(SearchHit hit) { + String[] result = new String[context.extractedFields.getAllFields().size()]; + for (int i = 0; i < result.length; ++i) { + ExtractedField field = context.extractedFields.getAllFields().get(i); + Object[] values = field.value(hit); + result[i] = (values.length == 1) ? Objects.toString(values[0]) : EMPTY_STRING; + } + return result; + } + + private List continueScroll() throws IOException { + LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId); + SearchResponse searchResponse = executeSearchScrollRequest(scrollId); + LOGGER.debug("[{}] Search response was obtained", context.jobId); + return processSearchResponse(searchResponse); + } + + private void markScrollAsErrored() { + // This could be a transient error with the scroll Id. + // Reinitialise the scroll and try again but only once. + resetScroll(); + searchHasShardFailure = true; + } + + protected SearchResponse executeSearchScrollRequest(String scrollId) { + return ClientHelper.executeWithHeaders(context.headers, ClientHelper.ML_ORIGIN, client, + () -> new SearchScrollRequestBuilder(client, SearchScrollAction.INSTANCE) + .setScroll(SCROLL_TIMEOUT) + .setScrollId(scrollId) + .get()); + } + + private void resetScroll() { + clearScroll(scrollId); + scrollId = null; + } + + private void clearScroll(String scrollId) { + if (scrollId != null) { + ClearScrollRequest request = new ClearScrollRequest(); + request.addScrollId(scrollId); + ClientHelper.executeWithHeaders(context.headers, ClientHelper.ML_ORIGIN, client, + () -> client.execute(ClearScrollAction.INSTANCE, request).actionGet()); + } + } + + public List getFieldNames() { + return context.extractedFields.getAllFields().stream().map(ExtractedField::getAlias).collect(Collectors.toList()); + } + + public String[] getFieldNamesArray() { + List fieldNames = getFieldNames(); + return fieldNames.toArray(new 
String[fieldNames.size()]); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java new file mode 100644 index 0000000000000..2acba64197c4b --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.analytics; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class DataFrameDataExtractorContext { + + final String jobId; + final ExtractedFields extractedFields; + final String[] indices; + final QueryBuilder query; + final int scrollSize; + final Map headers; + + DataFrameDataExtractorContext(String jobId, ExtractedFields extractedFields, List indices, QueryBuilder query, int scrollSize, + Map headers) { + this.jobId = Objects.requireNonNull(jobId); + this.extractedFields = Objects.requireNonNull(extractedFields); + this.indices = indices.toArray(new String[indices.size()]); + this.query = Objects.requireNonNull(query); + this.scrollSize = scrollSize; + this.headers = headers; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java new file mode 100644 index 0000000000000..b67e744ad5608 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics;

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest;
import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.index.IndexNotFoundException;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

/**
 * Creates {@link DataFrameDataExtractor} instances for a given source index.
 * The factory is built asynchronously via {@link #create} because it must first
 * fetch the index's field capabilities to decide which fields can be extracted.
 */
public class DataFrameDataExtractorFactory {

    /**
     * Fields to ignore. These are mostly internal meta fields.
     */
    private static final List<String> IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no",
        "_source", "_type", "_uid", "_version", "_feature", "_ignored");

    private final Client client;
    private final String index;
    private final ExtractedFields extractedFields;

    private DataFrameDataExtractorFactory(Client client, String index, ExtractedFields extractedFields) {
        this.client = Objects.requireNonNull(client);
        this.index = Objects.requireNonNull(index);
        this.extractedFields = Objects.requireNonNull(extractedFields);
    }

    /**
     * @return a new extractor configured to scroll the whole index with a match-all query
     */
    public DataFrameDataExtractor newExtractor() {
        DataFrameDataExtractorContext context = new DataFrameDataExtractorContext(
            "ml-analytics-" + index,
            extractedFields,
            Arrays.asList(index),
            QueryBuilders.matchAllQuery(),
            1000,
            Collections.emptyMap());
        return new DataFrameDataExtractor(client, context);
    }

    /**
     * Asynchronously constructs a factory for the given index.
     *
     * @param client   client to execute requests with
     * @param headers  security headers to attach to the field capabilities request
     * @param index    the source index
     * @param listener notified with the factory, or a {@link ResourceNotFoundException}
     *                 if the index does not exist
     */
    public static void create(Client client, Map<String, String> headers, String index,
                              ActionListener<DataFrameDataExtractorFactory> listener) {

        // Step 2. Construct the factory and notify listener
        ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesHandler = ActionListener.wrap(
            fieldCapabilitiesResponse -> {
                listener.onResponse(new DataFrameDataExtractorFactory(client, index, detectExtractedFields(fieldCapabilitiesResponse)));
            }, e -> {
                if (e instanceof IndexNotFoundException) {
                    listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index "
                        + ((IndexNotFoundException) e).getIndex() + " does not exist"));
                } else {
                    listener.onFailure(e);
                }
            }
        );

        // Step 1. Get field capabilities necessary to build the information of how to extract fields
        FieldCapabilitiesRequest fieldCapabilitiesRequest = new FieldCapabilitiesRequest();
        fieldCapabilitiesRequest.indices(index);
        fieldCapabilitiesRequest.fields("*");
        ClientHelper.executeWithHeaders(headers, ClientHelper.ML_ORIGIN, client, () -> {
            client.execute(FieldCapabilitiesAction.INSTANCE, fieldCapabilitiesRequest, fieldCapabilitiesHandler);
            // This response gets discarded - the listener handles the real response
            return null;
        });
    }

    private static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapabilitiesResponse) {
        // Copy the key set before mutating it: removing entries from the live keySet view would
        // modify the response's internal map (or throw if that map is unmodifiable)
        Set<String> fields = new HashSet<>(fieldCapabilitiesResponse.get().keySet());
        fields.removeAll(IGNORE_FIELDS);
        return ExtractedFields.build(new ArrayList<>(fields), Collections.emptySet(), fieldCapabilitiesResponse)
            .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE);
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics.process;

import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Assembles the command line for the native data frame analyzer process and
 * launches it through the {@link NativeController}.
 */
public class AnalyticsBuilder {

    public static final String ANALYTICS = "data_frame_analyzer";
    private static final String ANALYTICS_PATH = "./" + ANALYTICS;

    private static final String LENGTH_ENCODED_INPUT_ARG = "--lengthEncodedInput";

    private final NativeController nativeController;
    private final ProcessPipes processPipes;

    public AnalyticsBuilder(NativeController nativeController, ProcessPipes processPipes) {
        this.nativeController = Objects.requireNonNull(nativeController);
        this.processPipes = Objects.requireNonNull(processPipes);
    }

    /**
     * Builds the command, appends the pipe arguments, and starts the process.
     *
     * @throws IOException if the native controller fails to start the process
     */
    public void build() throws IOException {
        List<String> command = buildAnalyticsCommand();
        // The pipes contribute their own named-pipe arguments to the command line
        processPipes.addArgs(command);
        nativeController.startProcess(command);
    }

    // Package-private for testing
    List<String> buildAnalyticsCommand() {
        List<String> command = new ArrayList<>();
        command.add(ANALYTICS_PATH);
        command.add(LENGTH_ENCODED_INPUT_ARG);
        return command;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics.process;

import java.util.concurrent.ExecutorService;

/**
 * Factory abstraction for creating {@link AnalyticsProcess} instances, allowing
 * the native implementation to be swapped out (e.g. for tests or non-native
 * environments).
 */
public interface AnalyticsProcessFactory {

    /**
     * Creates an implementation of {@link AnalyticsProcess}.
     *
     * @param jobId           the job id
     * @param executorService executor used to run the async tasks the job needs
     *                        to operate the analytical process
     * @return the newly created process
     */
    AnalyticsProcess createAnalyticsProcess(String jobId, ExecutorService executorService);
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics.process;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.client.Client;
import org.elasticsearch.env.Environment;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor;

import java.io.IOException;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;

/**
 * Manages the lifecycle of analytics processes: starts a native process for a job,
 * streams the extracted data frame rows into it, and closes it when done.
 */
public class AnalyticsProcessManager {

    private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);

    private final Client client;
    private final Environment environment;
    private final ThreadPool threadPool;
    private final AnalyticsProcessFactory processFactory;

    public AnalyticsProcessManager(Client client, Environment environment, ThreadPool threadPool,
                                   AnalyticsProcessFactory analyticsProcessFactory) {
        this.client = Objects.requireNonNull(client);
        this.environment = Objects.requireNonNull(environment);
        this.threadPool = Objects.requireNonNull(threadPool);
        this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
    }

    /**
     * Asynchronously starts an analytics process for the given job and feeds it
     * all rows produced by the extractor. The process is always closed, whether
     * writing succeeds or fails.
     *
     * @param jobId         the job id
     * @param dataExtractor supplies the data frame rows to write to the process
     */
    public void processData(String jobId, DataFrameDataExtractor dataExtractor) {
        threadPool.generic().execute(() -> {
            AnalyticsProcess process = createProcess(jobId);
            try {
                // Fake header
                process.writeRecord(dataExtractor.getFieldNamesArray());

                while (dataExtractor.hasNext()) {
                    Optional<List<String[]>> records = dataExtractor.next();
                    if (records.isPresent()) {
                        for (String[] record : records.get()) {
                            process.writeRecord(record);
                        }
                    }
                }
                process.flushStream();
            } catch (IOException e) {
                // Don't swallow write failures silently; the process is still closed in the finally block
                LOGGER.error(() -> "[" + jobId + "] Error writing data to the analytics process", e);
            } finally {
                LOGGER.debug("[{}] Closing process", jobId);
                try {
                    // Close exactly once, on both success and failure paths
                    process.close();
                    LOGGER.info("[{}] Closed process", jobId);
                } catch (IOException e) {
                    LOGGER.error(() -> "[" + jobId + "] Error closing data frame analyzer process", e);
                }
            }
        });
    }

    private AnalyticsProcess createProcess(String jobId) {
        ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME);
        AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, executorService);
        if (process.isProcessAlive() == false) {
            throw ExceptionsHelper.serverError("Failed to start analytics process");
        }
        return process;
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics.process;

import org.elasticsearch.xpack.ml.process.AbstractNativeProcess;

import java.io.InputStream;
import java.io.OutputStream;
import java.nio.file.Path;
import java.util.List;

/**
 * Native implementation of {@link AnalyticsProcess}, a thin wrapper around
 * {@link AbstractNativeProcess} for the data frame analyzer. It has no state
 * of its own to persist.
 */
public class NativeAnalyticsProcess extends AbstractNativeProcess implements AnalyticsProcess {

    private static final String NAME = "analytics";

    protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream,
                                     OutputStream processRestoreStream, int numberOfFields, List<Path> filesToDelete,
                                     Runnable onProcessCrash) {
        super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash);
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public void persistState() {
        // Nothing to persist
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics.process;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.env.Environment;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.utils.NamedPipeHelper;

import java.io.IOException;
import java.time.Duration;
import java.util.Collections;
import java.util.Objects;
import java.util.concurrent.ExecutorService;

/**
 * {@link AnalyticsProcessFactory} implementation that launches the real native
 * data frame analyzer process and wires up its named pipes.
 */
public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory {

    private static final Logger LOGGER = LogManager.getLogger(NativeAnalyticsProcessFactory.class);

    private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper();

    /** Maximum time to wait for the native process's pipes to connect. */
    public static final Duration PROCESS_STARTUP_TIMEOUT = Duration.ofSeconds(10);

    private final Environment env;
    private final NativeController nativeController;

    public NativeAnalyticsProcessFactory(Environment env, NativeController nativeController) {
        this.env = Objects.requireNonNull(env);
        this.nativeController = Objects.requireNonNull(nativeController);
    }

    @Override
    public AnalyticsProcess createAnalyticsProcess(String jobId, ExecutorService executorService) {
        ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId,
            true, false, true, true, false, false);

        createNativeProcess(jobId, processPipes);

        NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, processPipes.getLogStream().get(),
            processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, 0,
            Collections.emptyList(), () -> {});

        try {
            analyticsProcess.start(executorService);
            return analyticsProcess;
        } catch (EsRejectedExecutionException e) {
            // Starting the reader tasks was rejected; clean up the process before propagating
            try {
                IOUtils.close(analyticsProcess);
            } catch (IOException ioe) {
                LOGGER.error("Can't close analytics process", ioe);
            }
            throw e;
        }
    }

    private void createNativeProcess(String jobId, ProcessPipes processPipes) {
        AnalyticsBuilder analyticsBuilder = new AnalyticsBuilder(nativeController, processPipes);
        try {
            analyticsBuilder.build();
            processPipes.connectStreams(PROCESS_STARTUP_TIMEOUT);
        } catch (IOException e) {
            String msg = "Failed to launch analytics for job " + jobId;
            // Log the cause too - without it the stack trace of the launch failure is lost
            LOGGER.error(msg, e);
            throw ExceptionsHelper.serverError(msg, e);
        }
    }
}
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.rest.analytics;

import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction;
import org.elasticsearch.xpack.ml.MachineLearning;

import java.io.IOException;

/**
 * REST endpoint that triggers an analytics run on an index:
 * {@code POST _xpack/ml/analytics/{index}/_run}.
 */
public class RestRunAnalyticsAction extends BaseRestHandler {

    public RestRunAnalyticsAction(Settings settings, RestController controller) {
        super(settings);
        controller.registerHandler(RestRequest.Method.POST, MachineLearning.BASE_PATH + "analytics/{index}/_run", this);
    }

    @Override
    public String getName() {
        return "xpack_ml_run_analytics_action";
    }

    @Override
    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
        String index = restRequest.param("index");
        RunAnalyticsAction.Request request = new RunAnalyticsAction.Request(index);
        return channel -> client.execute(RunAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel));
    }
}
--- .../xpack/ml/analytics/DataFrameDataExtractor.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java index 22f8befc55cd2..d0ff36b593d9d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -41,7 +41,6 @@ public class DataFrameDataExtractor { private static final Logger LOGGER = LogManager.getLogger(DataFrameDataExtractor.class); private static final TimeValue SCROLL_TIMEOUT = new TimeValue(30, TimeUnit.MINUTES); - private static final String EMPTY_STRING = ""; private final Client client; private final DataFrameDataExtractorContext context; @@ -141,7 +140,7 @@ private String[] toStringArray(SearchHit hit) { for (int i = 0; i < result.length; ++i) { ExtractedField field = context.extractedFields.getAllFields().get(i); Object[] values = field.value(hit); - result[i] = (values.length == 1) ? Objects.toString(values[0]) : EMPTY_STRING; + result[i] = (values.length == 1) ? 
Objects.toString(values[0]) : ""; } return result; } From 3dff025c885be32447ec7ae1e1d2883adfed2741 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 21 Nov 2018 13:57:32 +0000 Subject: [PATCH 03/67] [FEATURE][ML] Add TODO regarding renaming the job thread pool --- .../xpack/ml/analytics/process/AnalyticsProcessManager.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java index b4aa413e13ce1..4a24408dfc144 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java @@ -70,6 +70,7 @@ public void processData(String jobId, DataFrameDataExtractor dataExtractor) { } private AnalyticsProcess createProcess(String jobId) { + // TODO We should rename the thread pool to reflect its more general use now, e.g. 
JOB_PROCESS_THREAD_POOL_NAME ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, executorService); if (process.isProcessAlive() == false) { From 7b47107b4d7fbdd52aa6cd3459f167e2686f89a6 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 22 Nov 2018 16:25:04 +0000 Subject: [PATCH 04/67] [FEATURE][ML] Remove settings from run analytics action --- .../xpack/ml/action/TransportRunAnalyticsAction.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java index 6d713552832f9..570facb2b5756 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -10,7 +10,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.env.Environment; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -37,10 +36,10 @@ public class TransportRunAnalyticsAction extends HandledTransportAction) RunAnalyticsAction.Request::new); this.transportService = transportService; this.threadPool = threadPool; From 6e3f832bc62f5fba526d2321e23fdfdeab81adda Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 26 Nov 2018 16:45:27 +0000 Subject: [PATCH 05/67] [FEATURE][ML] Reindex data frame before starting analytics (#35835) With this commit before we start the analytics we first reindex the source data frame into a new index. 
Note we should maintain the settings and the mappings of the source index. --- .../action/TransportRunAnalyticsAction.java | 96 ++++++++++++++++++- .../ml/analytics/DataFrameDataExtractor.java | 2 +- .../xpack/ml/analytics/DataFrameFields.java | 13 +++ 3 files changed, 105 insertions(+), 6 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameFields.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java index 570facb2b5756..0c2bea17c7849 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -1,16 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ package org.elasticsearch.xpack.ml.action; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.admin.indices.create.CreateIndexAction; +import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; +import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetaData; +import org.elasticsearch.cluster.metadata.MappingMetaData; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.env.Environment; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.IndexSortConfig; +import org.elasticsearch.index.reindex.ReindexAction; +import org.elasticsearch.index.reindex.ReindexRequest; +import org.elasticsearch.script.Script; +import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -19,9 +37,13 @@ import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.analytics.DataFrameFields; import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessManager; +import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.List; import 
java.util.Map; import java.util.function.Supplier; @@ -34,6 +56,17 @@ public class TransportRunAnalyticsAction extends HandledTransportAction INTERNAL_SETTINGS = Arrays.asList( + "index.creation_date", + "index.provided_name", + "index.uuid", + "index.version.created" + ); @Inject public TransportRunAnalyticsAction(ThreadPool threadPool, TransportService transportService, ActionFilters actionFilters, @@ -53,7 +86,7 @@ public TransportRunAnalyticsAction(ThreadPool threadPool, TransportService trans protected void doExecute(Task task, RunAnalyticsAction.Request request, ActionListener listener) { DiscoveryNode localNode = clusterService.localNode(); if (isMlNode(localNode)) { - runPipelineAnalytics(request, listener); + reindexDataframeAndStartAnalysis(request.getIndex(), listener); return; } @@ -78,8 +111,62 @@ private boolean isMlNode(DiscoveryNode node) { return Boolean.valueOf(enabled); } - private void runPipelineAnalytics(RunAnalyticsAction.Request request, ActionListener listener) { - String jobId = "ml-analytics-" + request.getIndex(); + private void reindexDataframeAndStartAnalysis(String index, ActionListener listener) { + final String destinationIndex = index + "_copy"; + + ActionListener copyIndexCreatedListener = ActionListener.wrap( + createIndexResponse -> { + ReindexRequest reindexRequest = new ReindexRequest(); + reindexRequest.setSourceIndices(index); + reindexRequest.setDestIndex(destinationIndex); + reindexRequest.setScript(new Script("ctx._source." 
+ DataFrameFields.ID + " = ctx._id")); + client.execute(ReindexAction.INSTANCE, reindexRequest, ActionListener.wrap( + bulkResponse -> { + runPipelineAnalytics(destinationIndex, listener); + }, + listener::onFailure + )); + }, listener::onFailure + ); + + createDestinationIndex(index, destinationIndex, copyIndexCreatedListener); + } + + private void createDestinationIndex(String sourceIndex, String destinationIndex, ActionListener listener) { + IndexMetaData indexMetaData = clusterService.state().getMetaData().getIndices().get(sourceIndex); + if (indexMetaData == null) { + listener.onFailure(new IndexNotFoundException(sourceIndex)); + return; + } + + if (indexMetaData.getMappings().size() != 1) { + listener.onFailure(ExceptionsHelper.badRequestException("Does not support indices with multiple types")); + return; + } + + Settings.Builder settingsBuilder = Settings.builder().put(indexMetaData.getSettings()); + INTERNAL_SETTINGS.stream().forEach(settingsBuilder::remove); + settingsBuilder.put(IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), DataFrameFields.ID); + settingsBuilder.put(IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey(), SortOrder.ASC); + + CreateIndexRequest createIndexRequest = new CreateIndexRequest(destinationIndex, settingsBuilder.build()); + addDestinationIndexMappings(indexMetaData, createIndexRequest); + client.execute(CreateIndexAction.INSTANCE, createIndexRequest, listener); + } + + private static void addDestinationIndexMappings(IndexMetaData indexMetaData, CreateIndexRequest createIndexRequest) { + ImmutableOpenMap mappings = indexMetaData.getMappings(); + Map mappingsAsMap = mappings.valuesIt().next().sourceAsMap(); + Map properties = (Map) mappingsAsMap.get("properties"); + Map idCopyMapping = new HashMap<>(); + idCopyMapping.put("type", "keyword"); + properties.put(DataFrameFields.ID, idCopyMapping); + + createIndexRequest.mapping(mappings.keysIt().next(), mappingsAsMap); + } + + private void runPipelineAnalytics(String index, 
ActionListener listener) { + String jobId = "ml-analytics-" + index; ActionListener dataExtractorFactoryListener = ActionListener.wrap( dataExtractorFactory -> { @@ -90,7 +177,6 @@ private void runPipelineAnalytics(RunAnalyticsAction.Request request, ActionList listener::onFailure ); - DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), request.getIndex(), dataExtractorFactoryListener); + DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), index, dataExtractorFactoryListener); } - } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java index d0ff36b593d9d..f9d50f1c3343e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -93,7 +93,7 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest private SearchRequestBuilder buildSearchRequest() { SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) .setScroll(SCROLL_TIMEOUT) - .addSort("_doc", SortOrder.ASC) + .addSort(DataFrameFields.ID, SortOrder.ASC) .setIndices(context.indices) .setSize(context.scrollSize) .setQuery(context.query) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameFields.java new file mode 100644 index 0000000000000..8e7a8dd61a8ed --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameFields.java @@ -0,0 +1,13 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License;
 * you may not use this file except in compliance with the Elastic License.
 */
package org.elasticsearch.xpack.ml.analytics;

/**
 * Field names that ML data frame analytics adds to the reindexed copy of the
 * source index.
 */
public final class DataFrameFields {

    /** The field holding a copy of each document's original {@code _id}. */
    public static final String ID = "_id_copy";

    // Utility holder - not meant to be instantiated
    private DataFrameFields() {}
}
analyticsProcessFactory = (jobId, executorService) -> null; + analyticsProcessFactory = (jobId, analyticsProcessConfig, executorService) -> null; } NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory, threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java index 0c2bea17c7849..2a6199486a2ac 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -10,6 +10,8 @@ import org.elasticsearch.action.admin.indices.create.CreateIndexAction; import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.admin.indices.refresh.RefreshAction; +import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.action.support.master.AcknowledgedResponse; @@ -25,6 +27,7 @@ import org.elasticsearch.env.Environment; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexSortConfig; +import org.elasticsearch.index.reindex.BulkByScrollResponse; import org.elasticsearch.index.reindex.ReindexAction; import org.elasticsearch.index.reindex.ReindexRequest; import org.elasticsearch.script.Script; @@ -114,18 +117,23 @@ private boolean isMlNode(DiscoveryNode node) { private void reindexDataframeAndStartAnalysis(String index, ActionListener listener) { final String destinationIndex = index + "_copy"; + ActionListener reindexCompletedListener = ActionListener.wrap( + bulkResponse -> { + 
client.execute(RefreshAction.INSTANCE, new RefreshRequest(destinationIndex), ActionListener.wrap( + refreshResponse -> { + runPipelineAnalytics(destinationIndex, listener); + }, listener::onFailure + )); + }, listener::onFailure + ); + ActionListener copyIndexCreatedListener = ActionListener.wrap( createIndexResponse -> { ReindexRequest reindexRequest = new ReindexRequest(); reindexRequest.setSourceIndices(index); reindexRequest.setDestIndex(destinationIndex); reindexRequest.setScript(new Script("ctx._source." + DataFrameFields.ID + " = ctx._id")); - client.execute(ReindexAction.INSTANCE, reindexRequest, ActionListener.wrap( - bulkResponse -> { - runPipelineAnalytics(destinationIndex, listener); - }, - listener::onFailure - )); + client.execute(ReindexAction.INSTANCE, reindexRequest, reindexCompletedListener); }, listener::onFailure ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameAnalysis.java new file mode 100644 index 0000000000000..1b06e77e31b5a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameAnalysis.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.analytics; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; + +public class DataFrameAnalysis implements ToXContentObject { + + private static final ParseField NAME = new ParseField("name"); + + private final String name; + + public DataFrameAnalysis(String name) { + this.name = ExceptionsHelper.requireNonNull(name, NAME.getPreferredName()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NAME.getPreferredName(), name); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java index f9d50f1c3343e..c35ee278b4e8e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -189,4 +189,25 @@ public String[] getFieldNamesArray() { List fieldNames = getFieldNames(); return fieldNames.toArray(new String[fieldNames.size()]); } + + public DataSummary collectDataSummary() { + SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) + .setIndices(context.indices) + .setSize(0) + .setQuery(context.query); + + SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder); + return new DataSummary(searchResponse.getHits().getTotalHits(), context.extractedFields.getAllFields().size()); + } + + public static class DataSummary { + + public final long rows; + public final long cols; + + public DataSummary(long rows, long cols) { 
+ this.rows = rows; + this.cols = cols; + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java index 4f9627a402b64..e2b81ff547cc6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java @@ -5,10 +5,19 @@ */ package org.elasticsearch.xpack.ml.analytics.process; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.env.Environment; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.ProcessPipes; import java.io.IOException; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -19,13 +28,21 @@ public class AnalyticsBuilder { private static final String ANALYTICS_PATH = "./" + ANALYTICS; private static final String LENGTH_ENCODED_INPUT_ARG = "--lengthEncodedInput"; + private static final String CONFIG_ARG = "--config="; + private final Environment env; private final NativeController nativeController; private final ProcessPipes processPipes; + private final AnalyticsProcessConfig config; + private final List filesToDelete; - public AnalyticsBuilder(NativeController nativeController, ProcessPipes processPipes) { + public AnalyticsBuilder(Environment env, NativeController nativeController, ProcessPipes processPipes, AnalyticsProcessConfig config, + List filesToDelete) { + this.env = Objects.requireNonNull(env); this.nativeController = 
Objects.requireNonNull(nativeController); this.processPipes = Objects.requireNonNull(processPipes); + this.config = Objects.requireNonNull(config); + this.filesToDelete = Objects.requireNonNull(filesToDelete); } public void build() throws IOException { @@ -34,10 +51,24 @@ public void build() throws IOException { nativeController.startProcess(command); } - List buildAnalyticsCommand() { + List buildAnalyticsCommand() throws IOException { List command = new ArrayList<>(); command.add(ANALYTICS_PATH); command.add(LENGTH_ENCODED_INPUT_ARG); + addConfigFile(command); return command; } + + private void addConfigFile(List command) throws IOException { + Path configFile = Files.createTempFile(env.tmpFile(), "analysis", ".conf"); + filesToDelete.add(configFile); + try (OutputStreamWriter osw = new OutputStreamWriter(Files.newOutputStream(configFile),StandardCharsets.UTF_8); + XContentBuilder jsonBuilder = JsonXContent.contentBuilder()) { + + config.toXContent(jsonBuilder, ToXContent.EMPTY_PARAMS); + osw.write(Strings.toString(jsonBuilder)); + } + + command.add(CONFIG_ARG + configFile.toString()); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java new file mode 100644 index 0000000000000..62e97189d0dbf --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.analytics.process; + +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.ml.analytics.DataFrameAnalysis; + +import java.io.IOException; +import java.util.Objects; + +public class AnalyticsProcessConfig implements ToXContentObject { + + private static final String ROWS = "rows"; + private static final String COLS = "cols"; + private static final String MEMORY_LIMIT = "memory_limit"; + private static final String THREADS = "threads"; + private static final String ANALYSIS = "analysis"; + + private final long rows; + private final long cols; + private final ByteSizeValue memoryLimit; + private final int threads; + private final DataFrameAnalysis analysis; + + + public AnalyticsProcessConfig(long rows, long cols, ByteSizeValue memoryLimit, int threads, DataFrameAnalysis analysis) { + this.rows = rows; + this.cols = cols; + this.memoryLimit = Objects.requireNonNull(memoryLimit); + this.threads = threads; + this.analysis = Objects.requireNonNull(analysis); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ROWS, rows); + builder.field(COLS, cols); + builder.field(MEMORY_LIMIT, memoryLimit.getBytes()); + builder.field(THREADS, threads); + builder.field(ANALYSIS, analysis); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java index 330738cb69f7e..d0eb7a414074f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java @@ -13,8 +13,9 @@ public interface AnalyticsProcessFactory { * Create an implementation of {@link AnalyticsProcess} * * @param jobId The job id + * @param analyticsProcessConfig The process configuration * @param executorService Executor service used to start the async tasks a job needs to operate the analytical process * @return The process */ - AnalyticsProcess createAnalyticsProcess(String jobId, ExecutorService executorService); + AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig, ExecutorService executorService); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java index 4a24408dfc144..e0b939ad420cf 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java @@ -8,10 +8,13 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.client.Client; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.env.Environment; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.analytics.DataFrameAnalysis; import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; import java.io.IOException; @@ -39,7 +42,7 @@ public AnalyticsProcessManager(Client client, Environment environment, ThreadPoo public void processData(String jobId, DataFrameDataExtractor dataExtractor) { threadPool.generic().execute(() -> { - 
AnalyticsProcess process = createProcess(jobId); + AnalyticsProcess process = createProcess(jobId, dataExtractor); try { // Fake header process.writeRecord(dataExtractor.getFieldNamesArray()); @@ -69,13 +72,20 @@ public void processData(String jobId, DataFrameDataExtractor dataExtractor) { }); } - private AnalyticsProcess createProcess(String jobId) { + private AnalyticsProcess createProcess(String jobId, DataFrameDataExtractor dataExtractor) { // TODO We should rename the thread pool to reflect its more general use now, e.g. JOB_PROCESS_THREAD_POOL_NAME ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); - AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, executorService); + AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, createProcessConfig(dataExtractor), executorService); if (process.isProcessAlive() == false) { throw ExceptionsHelper.serverError("Failed to start analytics process"); } return process; } + + private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataExtractor) { + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + AnalyticsProcessConfig config = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, + new ByteSizeValue(1, ByteSizeUnit.GB), 1, new DataFrameAnalysis("outliers")); + return config; + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java index 125e8d12d90ec..7d8d1557ad967 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java @@ -16,8 +16,10 @@ import org.elasticsearch.xpack.ml.utils.NamedPipeHelper; import 
java.io.IOException; +import java.nio.file.Path; import java.time.Duration; -import java.util.Collections; +import java.util.ArrayList; +import java.util.List; import java.util.Objects; import java.util.concurrent.ExecutorService; @@ -37,15 +39,17 @@ public NativeAnalyticsProcessFactory(Environment env, NativeController nativeCon } @Override - public AnalyticsProcess createAnalyticsProcess(String jobId, ExecutorService executorService) { + public AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig, + ExecutorService executorService) { + List filesToDelete = new ArrayList<>(); ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId, - true, false, true, true, false, false); + true, false, true, true, false, false); - createNativeProcess(jobId, processPipes); + createNativeProcess(jobId, analyticsProcessConfig, filesToDelete, processPipes); NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, processPipes.getLogStream().get(), - processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, 0, - Collections.emptyList(), () -> {}); + processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, 0, + filesToDelete, () -> {}); try { @@ -61,8 +65,10 @@ public AnalyticsProcess createAnalyticsProcess(String jobId, ExecutorService exe } } - private void createNativeProcess(String jobId, ProcessPipes processPipes) { - AnalyticsBuilder analyticsBuilder = new AnalyticsBuilder(nativeController, processPipes); + private void createNativeProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig, List filesToDelete, + ProcessPipes processPipes) { + AnalyticsBuilder analyticsBuilder = new AnalyticsBuilder(env, nativeController, processPipes, analyticsProcessConfig, + filesToDelete); try { analyticsBuilder.build(); processPipes.connectStreams(PROCESS_STARTUP_TIMEOUT); From 
801665a66755735446fd9387c583d997dacd294f Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 28 Nov 2018 13:19:11 +0000 Subject: [PATCH 07/67] [FEATURE][ML] Only write numeric fields to data frame (#35961) --- .../action/TransportRunAnalyticsAction.java | 3 + .../DataFrameDataExtractorFactory.java | 41 ++++++- .../DataFrameDataExtractorFactoryTests.java | 115 ++++++++++++++++++ 3 files changed, 157 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java index 2a6199486a2ac..7d7b194de8ae6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -185,6 +185,9 @@ private void runPipelineAnalytics(String index, ActionListener IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", "_source", "_type", "_uid", "_version", "_feature", "_ignored"); + /** + * The types supported by data frames + */ + private static final Set COMPATIBLE_FIELD_TYPES; + + static { + Set compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values()) + .map(NumberFieldMapper.NumberType::typeName) + .collect(Collectors.toSet()); + compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module + + COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes); + } + private final Client client; private final String index; private final ExtractedFields extractedFields; @@ -82,10 +102,27 @@ public static void create(Client client, Map headers, String ind }); } - private static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse 
fieldCapabilitiesResponse) { + // Visible for testing + static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapabilitiesResponse) { Set fields = fieldCapabilitiesResponse.get().keySet(); fields.removeAll(IGNORE_FIELDS); - return ExtractedFields.build(new ArrayList<>(fields), Collections.emptySet(), fieldCapabilitiesResponse) + removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse); + ExtractedFields extractedFields = ExtractedFields.build(new ArrayList<>(fields), Collections.emptySet(), fieldCapabilitiesResponse) .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE); + if (extractedFields.getAllFields().isEmpty()) { + throw ExceptionsHelper.badRequestException("No compatible fields could be detected"); + } + return extractedFields; + } + + private static void removeFieldsWithIncompatibleTypes(Set fields, FieldCapabilitiesResponse fieldCapabilitiesResponse) { + Iterator fieldsIterator = fields.iterator(); + while (fieldsIterator.hasNext()) { + String field = fieldsIterator.next(); + Map fieldCaps = fieldCapabilitiesResponse.getField(field); + if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) { + fieldsIterator.remove(); + } + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java new file mode 100644 index 0000000000000..1a43b2893baef --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java @@ -0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.analytics; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.fieldcaps.FieldCapabilities; +import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class DataFrameDataExtractorFactoryTests extends ESTestCase { + + public void testDetectExtractedFields_GivenFloatField() { + FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float").build(); + + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + + List allFields = extractedFields.getAllFields(); + assertThat(allFields.size(), equalTo(1)); + assertThat(allFields.get(0).getName(), equalTo("some_float")); + } + + public void testDetectExtractedFields_GivenNumericFieldWithMultipleTypes() { + FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") + .build(); + + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + + List allFields = extractedFields.getAllFields(); + assertThat(allFields.size(), equalTo(1)); + assertThat(allFields.get(0).getName(), equalTo("some_number")); + } + + public void testDetectExtractedFields_GivenNonNumericField() { + FieldCapabilitiesResponse fieldCapabilities= new 
MockFieldCapsResponseBuilder() + .addAggregatableField("some_keyword", "keyword").build(); + + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected")); + } + + public void testDetectExtractedFields_GivenFieldWithNumericAndNonNumericTypes() { + FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + .addAggregatableField("indecisive_field", "float", "keyword").build(); + + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected")); + } + + public void testDetectExtractedFields_GivenMultipleFields() { + FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + .addAggregatableField("some_float", "float") + .addAggregatableField("some_long", "long") + .addAggregatableField("some_keyword", "keyword") + .build(); + + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + + List allFields = extractedFields.getAllFields(); + assertThat(allFields.size(), equalTo(2)); + assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toSet()), + containsInAnyOrder("some_float", "some_long")); + } + + public void testDetectExtractedFields_GivenIgnoredField() { + FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() + .addAggregatableField("_id", "float").build(); + + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected")); + } + + private static class MockFieldCapsResponseBuilder { + + 
private final Map> fieldCaps = new HashMap<>(); + + private MockFieldCapsResponseBuilder addAggregatableField(String field, String... types) { + Map caps = new HashMap<>(); + for (String type : types) { + caps.put(type, new FieldCapabilities(field, type, true, true)); + } + fieldCaps.put(field, caps); + return this; + } + + private FieldCapabilitiesResponse build() { + FieldCapabilitiesResponse response = mock(FieldCapabilitiesResponse.class); + when(response.get()).thenReturn(fieldCaps); + + for (String field : fieldCaps.keySet()) { + when(response.getField(field)).thenReturn(fieldCaps.get(field)); + } + return response; + } + } +} From 8cf53372bd5d29b0bb53d50d0989d9f611b4605d Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 30 Nov 2018 11:30:32 +0000 Subject: [PATCH 08/67] [FEATURE][ML] Skip rows that have missing data (#36067) This also prepares for allowing the `DataFrameDataExtractor` to be reused while joining the results with the raw documents. --- .../ml/analytics/DataFrameDataExtractor.java | 52 +++++++++++++------ .../process/AnalyticsProcessManager.java | 11 ++-- 2 files changed, 44 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java index c35ee278b4e8e..2a8c39150ff92 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.search.SearchScrollAction; import org.elasticsearch.action.search.SearchScrollRequestBuilder; import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.sort.SortOrder; @@ -68,18 +69,18 @@ 
public void cancel() { isCancelled = true; } - public Optional> next() throws IOException { + public Optional> next() throws IOException { if (!hasNext()) { throw new NoSuchElementException(); } - Optional> records = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll()); - if (!records.isPresent()) { + Optional> hits = scrollId == null ? Optional.ofNullable(initScroll()) : Optional.ofNullable(continueScroll()); + if (!hits.isPresent()) { hasNext = false; } - return records; + return hits; } - protected List initScroll() throws IOException { + protected List initScroll() throws IOException { LOGGER.debug("[{}] Initializing scroll", "analytics"); SearchResponse searchResponse = executeSearchRequest(buildSearchRequest()); LOGGER.debug("[{}] Search response was obtained", context.jobId); @@ -106,7 +107,7 @@ private SearchRequestBuilder buildSearchRequest() { return searchRequestBuilder; } - private List processSearchResponse(SearchResponse searchResponse) throws IOException { + private List processSearchResponse(SearchResponse searchResponse) throws IOException { if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) { LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId); @@ -123,29 +124,35 @@ private List processSearchResponse(SearchResponse searchResponse) thro } SearchHit[] hits = searchResponse.getHits().getHits(); - List records = new ArrayList<>(hits.length); + List rows = new ArrayList<>(hits.length); for (SearchHit hit : hits) { if (isCancelled) { hasNext = false; clearScroll(scrollId); break; } - records.add(toStringArray(hit)); + rows.add(createRow(hit)); } - return records; + return rows; + } - private String[] toStringArray(SearchHit hit) { - String[] result = new String[context.extractedFields.getAllFields().size()]; - for (int i = 0; i < result.length; ++i) { + private Row createRow(SearchHit hit) { + String[] extractedValues = new 
String[context.extractedFields.getAllFields().size()]; + for (int i = 0; i < extractedValues.length; ++i) { ExtractedField field = context.extractedFields.getAllFields().get(i); Object[] values = field.value(hit); - result[i] = (values.length == 1) ? Objects.toString(values[0]) : ""; + if (values.length == 1 && values[0] instanceof Number) { + extractedValues[i] = Objects.toString(values[0]); + } else { + extractedValues = null; + break; + } } - return result; + return new Row(extractedValues); } - private List continueScroll() throws IOException { + private List continueScroll() throws IOException { LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId); SearchResponse searchResponse = executeSearchScrollRequest(scrollId); LOGGER.debug("[{}] Search response was obtained", context.jobId); @@ -210,4 +217,19 @@ public DataSummary(long rows, long cols) { this.cols = cols; } } + + public static class Row { + + @Nullable + private String[] values; + + private Row(String[] values) { + this.values = values; + } + + @Nullable + public String[] getValues() { + return values; + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java index e0b939ad420cf..09a78014e1bc3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java @@ -48,10 +48,13 @@ public void processData(String jobId, DataFrameDataExtractor dataExtractor) { process.writeRecord(dataExtractor.getFieldNamesArray()); while (dataExtractor.hasNext()) { - Optional> records = dataExtractor.next(); - if (records.isPresent()) { - for (String[] record : records.get()) { - process.writeRecord(record); + Optional> rows = dataExtractor.next(); + if 
(rows.isPresent()) { + for (DataFrameDataExtractor.Row row : rows.get()) { + String[] rowValues = row.getValues(); + if (rowValues != null) { + process.writeRecord(rowValues); + } } } } From 88496c2871333a32c7566755ef71c00e119e8510 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 3 Dec 2018 16:37:57 +0000 Subject: [PATCH 09/67] [FEATURE][ML] Write control message to signify end of data (#36158) --- .../ml/analytics/DataFrameDataExtractor.java | 9 +--- .../AnalyticsControlMessageWriter.java | 38 ++++++++++++++ .../analytics/process/AnalyticsProcess.java | 9 ++++ .../process/AnalyticsProcessConfig.java | 8 ++- .../process/AnalyticsProcessManager.java | 51 +++++++++++++------ .../process/NativeAnalyticsProcess.java | 6 +++ .../NativeAnalyticsProcessFactory.java | 5 +- .../AnalyticsControlMessageWriterTests.java | 50 ++++++++++++++++++ 8 files changed, 151 insertions(+), 25 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java index 2a8c39150ff92..b0d5a032665f8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -192,11 +192,6 @@ public List getFieldNames() { return context.extractedFields.getAllFields().stream().map(ExtractedField::getAlias).collect(Collectors.toList()); } - public String[] getFieldNamesArray() { - List fieldNames = getFieldNames(); - return fieldNames.toArray(new String[fieldNames.size()]); - } - public DataSummary collectDataSummary() { 
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) .setIndices(context.indices) @@ -210,9 +205,9 @@ public DataSummary collectDataSummary() { public static class DataSummary { public final long rows; - public final long cols; + public final int cols; - public DataSummary(long rows, long cols) { + public DataSummary(long rows, int cols) { this.rows = rows; this.cols = cols; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java new file mode 100644 index 0000000000000..ff51ef1122c8d --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.analytics.process; + +import org.elasticsearch.xpack.ml.process.writer.AbstractControlMsgWriter; +import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter; + +import java.io.IOException; + +public class AnalyticsControlMessageWriter extends AbstractControlMsgWriter { + + /** + * This must match the code defined in the api::CDataFrameAnalyzer C++ class. + * The constant there is referred as RUN_ANALYSIS_CONTROL_MESSAGE_FIELD_VALUE + * but in the context of the java side it is more descriptive to call this the + * end of data message. 
+ */ + private static final String END_OF_DATA_MESSAGE_CODE = "r"; + + /** + * Construct the control message writer with a LengthEncodedWriter + * + * @param lengthEncodedWriter The writer + * @param numberOfFields The number of fields the process expects in each record + */ + public AnalyticsControlMessageWriter(LengthEncodedWriter lengthEncodedWriter, int numberOfFields) { + super(lengthEncodedWriter, numberOfFields); + } + + public void writeEndOfData() throws IOException { + writeMessage(END_OF_DATA_MESSAGE_CODE); + fillCommandBuffer(); + lengthEncodedWriter.flush(); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java index e9ae8ebacf86f..932d144cf6973 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java @@ -7,5 +7,14 @@ import org.elasticsearch.xpack.ml.process.NativeProcess; +import java.io.IOException; + public interface AnalyticsProcess extends NativeProcess { + + /** + * Writes a control message that informs the process + * all data has been sent + * @throws IOException If an error occurs writing to the process + */ + void writeEndOfDataMessage() throws IOException; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java index 62e97189d0dbf..4ee543ed078a4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java @@ -22,13 +22,13 @@ public class AnalyticsProcessConfig implements ToXContentObject { 
private static final String ANALYSIS = "analysis"; private final long rows; - private final long cols; + private final int cols; private final ByteSizeValue memoryLimit; private final int threads; private final DataFrameAnalysis analysis; - public AnalyticsProcessConfig(long rows, long cols, ByteSizeValue memoryLimit, int threads, DataFrameAnalysis analysis) { + public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, DataFrameAnalysis analysis) { this.rows = rows; this.cols = cols; this.memoryLimit = Objects.requireNonNull(memoryLimit); @@ -36,6 +36,10 @@ public AnalyticsProcessConfig(long rows, long cols, ByteSizeValue memoryLimit, i this.analysis = Objects.requireNonNull(analysis); } + public int cols() { + return cols; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java index 09a78014e1bc3..2671f49293863 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java @@ -7,6 +7,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.client.Client; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; @@ -44,27 +45,16 @@ public void processData(String jobId, DataFrameDataExtractor dataExtractor) { threadPool.generic().execute(() -> { AnalyticsProcess process = createProcess(jobId, dataExtractor); try { - // Fake header - process.writeRecord(dataExtractor.getFieldNamesArray()); - - while 
(dataExtractor.hasNext()) { - Optional> rows = dataExtractor.next(); - if (rows.isPresent()) { - for (DataFrameDataExtractor.Row row : rows.get()) { - String[] rowValues = row.getValues(); - if (rowValues != null) { - process.writeRecord(rowValues); - } - } - } - } + writeHeaderRecord(dataExtractor, process); + writeDataRows(dataExtractor, process); + process.writeEndOfDataMessage(); process.flushStream(); LOGGER.debug("[{}] Closing process", jobId); process.close(); LOGGER.info("[{}] Closed process", jobId); } catch (IOException e) { - + LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", jobId), e); } finally { try { process.close(); @@ -75,6 +65,37 @@ public void processData(String jobId, DataFrameDataExtractor dataExtractor) { }); } + private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { + // The extra field is the control field (should be an empty string) + String[] record = new String[dataExtractor.getFieldNames().size() + 1]; + // The value of the control field should be an empty string for data frame rows + record[record.length - 1] = ""; + + while (dataExtractor.hasNext()) { + Optional> rows = dataExtractor.next(); + if (rows.isPresent()) { + for (DataFrameDataExtractor.Row row : rows.get()) { + String[] rowValues = row.getValues(); + if (rowValues != null) { + System.arraycopy(rowValues, 0, record, 0, rowValues.length); + process.writeRecord(record); + } + } + } + } + } + + private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { + List fieldNames = dataExtractor.getFieldNames(); + String[] headerRecord = new String[fieldNames.size() + 1]; + for (int i = 0; i < fieldNames.size(); i++) { + headerRecord[i] = fieldNames.get(i); + } + // The field name of the control field is dot + headerRecord[headerRecord.length - 1] = "."; + process.writeRecord(headerRecord); + } + private AnalyticsProcess createProcess(String 
jobId, DataFrameDataExtractor dataExtractor) { // TODO We should rename the thread pool to reflect its more general use now, e.g. JOB_PROCESS_THREAD_POOL_NAME ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java index 12f41f83568dd..b8a387fd6e2b2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java @@ -7,6 +7,7 @@ import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.file.Path; @@ -31,4 +32,9 @@ public String getName() { public void persistState() { // Nothing to persist } + + @Override + public void writeEndOfDataMessage() throws IOException { + new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData(); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java index 7d8d1557ad967..3943e7fd7924f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java @@ -45,10 +45,13 @@ public AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessCon ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId, true, false, true, true, false, false); + // The extra 1 is the control field + 
int numberOfFields = analyticsProcessConfig.cols() + 1; + createNativeProcess(jobId, analyticsProcessConfig, filesToDelete, processPipes); NativeAnalyticsProcess analyticsProcess = new NativeAnalyticsProcess(jobId, processPipes.getLogStream().get(), - processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, 0, + processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), null, numberOfFields, filesToDelete, () -> {}); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java new file mode 100644 index 0000000000000..1adf91fa884c8 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.analytics.process; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter; +import org.junit.Before; +import org.mockito.InOrder; +import org.mockito.Mockito; + +import java.io.IOException; +import java.util.stream.IntStream; + +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +public class AnalyticsControlMessageWriterTests extends ESTestCase { + + private LengthEncodedWriter lengthEncodedWriter; + + @Before + public void setUpMocks() { + lengthEncodedWriter = Mockito.mock(LengthEncodedWriter.class); + } + + public void testWriteEndOfData() throws IOException { + AnalyticsControlMessageWriter writer = new AnalyticsControlMessageWriter(lengthEncodedWriter, 4); + + writer.writeEndOfData(); + + InOrder inOrder = inOrder(lengthEncodedWriter); + inOrder.verify(lengthEncodedWriter).writeNumFields(4); + inOrder.verify(lengthEncodedWriter, times(3)).writeField(""); + inOrder.verify(lengthEncodedWriter).writeField("r"); + + StringBuilder spaces = new StringBuilder(); + IntStream.rangeClosed(1, 8192).forEach(i -> spaces.append(' ')); + inOrder.verify(lengthEncodedWriter).writeNumFields(4); + inOrder.verify(lengthEncodedWriter, times(3)).writeField(""); + inOrder.verify(lengthEncodedWriter).writeField(spaces.toString()); + + inOrder.verify(lengthEncodedWriter).flush(); + + verifyNoMoreInteractions(lengthEncodedWriter); + } +} From fd3e6a972bd658cee75904140b0bc6ecc051a5cd Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 11 Dec 2018 12:17:51 +0000 Subject: [PATCH 10/67] [FEATURE][ML] Parse results and join them in the data-frame copy index (#36382) --- .../action/TransportRunAnalyticsAction.java | 4 +- .../ml/analytics/DataFrameDataExtractor.java | 22 +- .../DataFrameDataExtractorContext.java | 4 +- .../DataFrameDataExtractorFactory.java | 6 +- 
.../AnalyticsControlMessageWriter.java | 2 +- .../analytics/process/AnalyticsProcess.java | 14 + .../process/AnalyticsProcessManager.java | 52 ++- .../ml/analytics/process/AnalyticsResult.java | 73 +++ .../process/AnalyticsResultProcessor.java | 115 +++++ .../process/NativeAnalyticsProcess.java | 9 + .../autodetect/NativeAutodetectProcess.java | 19 +- .../NativeAutodetectProcessFactory.java | 9 +- .../ml/process/AbstractNativeProcess.java | 12 + .../ProcessResultsParser.java} | 34 +- .../AnalyticsControlMessageWriterTests.java | 2 +- .../AnalyticsResultProcessorTests.java | 124 +++++ .../process/AnalyticsResultTests.java | 39 ++ .../NativeAutodetectProcessTests.java | 13 +- .../output/AutodetectResultsParserTests.java | 422 ------------------ .../ml/process/ProcessResultsParserTests.java | 113 +++++ 20 files changed, 594 insertions(+), 494 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{job/process/autodetect/output/AutodetectResultsParser.java => process/ProcessResultsParser.java} (72%) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParserTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java index 
7d7b194de8ae6..f8f1226492941 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -38,7 +38,6 @@ import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.analytics.DataFrameFields; import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessManager; @@ -178,8 +177,7 @@ private void runPipelineAnalytics(String index, ActionListener dataExtractorFactoryListener = ActionListener.wrap( dataExtractorFactory -> { - DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(); - analyticsProcessManager.processData(jobId, dataExtractor); + analyticsProcessManager.runJob(jobId, dataExtractorFactory); listener.onResponse(new AcknowledgedResponse(true)); }, listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java index b0d5a032665f8..3d47aeff1b8b5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -98,7 +98,7 @@ private SearchRequestBuilder buildSearchRequest() { .setIndices(context.indices) .setSize(context.scrollSize) .setQuery(context.query) - .setFetchSource(false); + .setFetchSource(context.includeSource); for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) { searchRequestBuilder.addDocValueField(docValueField.getName(), 
docValueField.getDocValueFormat()); @@ -149,7 +149,7 @@ private Row createRow(SearchHit hit) { break; } } - return new Row(extractedValues); + return new Row(extractedValues, hit); } private List continueScroll() throws IOException { @@ -196,10 +196,11 @@ public DataSummary collectDataSummary() { SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) .setIndices(context.indices) .setSize(0) - .setQuery(context.query); + .setQuery(context.query) + .setTrackTotalHits(true); SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder); - return new DataSummary(searchResponse.getHits().getTotalHits(), context.extractedFields.getAllFields().size()); + return new DataSummary(searchResponse.getHits().getTotalHits().value, context.extractedFields.getAllFields().size()); } public static class DataSummary { @@ -215,16 +216,27 @@ public DataSummary(long rows, int cols) { public static class Row { + private SearchHit hit; + @Nullable private String[] values; - private Row(String[] values) { + private Row(String[] values, SearchHit hit) { this.values = values; + this.hit = hit; } @Nullable public String[] getValues() { return values; } + + public SearchHit getHit() { + return hit; + } + + public boolean shouldSkip() { + return values == null; + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java index 2acba64197c4b..d1b52bdac0351 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java @@ -20,14 +20,16 @@ public class DataFrameDataExtractorContext { final QueryBuilder query; final int scrollSize; final Map headers; + final boolean includeSource; DataFrameDataExtractorContext(String 
jobId, ExtractedFields extractedFields, List indices, QueryBuilder query, int scrollSize, - Map headers) { + Map headers, boolean includeSource) { this.jobId = Objects.requireNonNull(jobId); this.extractedFields = Objects.requireNonNull(extractedFields); this.indices = indices.toArray(new String[indices.size()]); this.query = Objects.requireNonNull(query); this.scrollSize = scrollSize; this.headers = headers; + this.includeSource = includeSource; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java index 35085d282c87f..57c2b44f769c4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java @@ -63,14 +63,16 @@ private DataFrameDataExtractorFactory(Client client, String index, ExtractedFiel this.extractedFields = Objects.requireNonNull(extractedFields); } - public DataFrameDataExtractor newExtractor() { + public DataFrameDataExtractor newExtractor(boolean includeSource) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( "ml-analytics-" + index, extractedFields, Arrays.asList(index), QueryBuilders.matchAllQuery(), 1000, - Collections.emptyMap()); + Collections.emptyMap(), + includeSource + ); return new DataFrameDataExtractor(client, context); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java index ff51ef1122c8d..0500b51f85b2a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java @@ -18,7 +18,7 @@ public class AnalyticsControlMessageWriter extends AbstractControlMsgWriter { * but in the context of the java side it is more descriptive to call this the * end of data message. */ - private static final String END_OF_DATA_MESSAGE_CODE = "r"; + private static final String END_OF_DATA_MESSAGE_CODE = "$"; /** * Construct the control message writer with a LengthEncodedWriter diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java index 932d144cf6973..dc07d688a67ae 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java @@ -8,6 +8,7 @@ import org.elasticsearch.xpack.ml.process.NativeProcess; import java.io.IOException; +import java.util.Iterator; public interface AnalyticsProcess extends NativeProcess { @@ -17,4 +18,17 @@ public interface AnalyticsProcess extends NativeProcess { * @throws IOException If an error occurs writing to the process */ void writeEndOfDataMessage() throws IOException; + + /** + * @return stream of analytics results. 
+ */ + Iterator readAnalyticsResults(); + + /** + * Read anything left in the stream before + * closing the stream otherwise if the process + * tries to write more after the close it gets + * a SIGPIPE + */ + void consumeAndCloseOutputStream(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java index 2671f49293863..bf079348c7e55 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.analytics.DataFrameAnalysis; import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory; import java.io.IOException; import java.util.List; @@ -41,28 +42,39 @@ public AnalyticsProcessManager(Client client, Environment environment, ThreadPoo this.processFactory = Objects.requireNonNull(analyticsProcessFactory); } - public void processData(String jobId, DataFrameDataExtractor dataExtractor) { + public void runJob(String jobId, DataFrameDataExtractorFactory dataExtractorFactory) { threadPool.generic().execute(() -> { - AnalyticsProcess process = createProcess(jobId, dataExtractor); - try { - writeHeaderRecord(dataExtractor, process); - writeDataRows(dataExtractor, process); - process.writeEndOfDataMessage(); - process.flushStream(); + DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); + AnalyticsProcess process = createProcess(jobId, createProcessConfig(dataExtractor)); + ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); + AnalyticsResultProcessor resultProcessor = new 
AnalyticsResultProcessor(client, dataExtractorFactory.newExtractor(true)); + executorService.execute(() -> resultProcessor.process(process)); + executorService.execute(() -> processData(jobId, dataExtractor, process, resultProcessor)); + }); + } + + private void processData(String jobId, DataFrameDataExtractor dataExtractor, AnalyticsProcess process, + AnalyticsResultProcessor resultProcessor) { + try { + writeHeaderRecord(dataExtractor, process); + writeDataRows(dataExtractor, process); + process.writeEndOfDataMessage(); + process.flushStream(); - LOGGER.debug("[{}] Closing process", jobId); + LOGGER.info("[{}] Waiting for result processor to complete", jobId); + resultProcessor.awaitForCompletion(); + LOGGER.info("[{}] Result processor has completed", jobId); + } catch (IOException e) { + LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", jobId), e); + } finally { + LOGGER.info("[{}] Closing process", jobId); + try { process.close(); LOGGER.info("[{}] Closed process", jobId); } catch (IOException e) { - LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", jobId), e); - } finally { - try { - process.close(); - } catch (IOException e) { - LOGGER.error("[{}] Error closing data frame analyzer process", jobId); - } + LOGGER.error("[{}] Error closing data frame analyzer process", jobId); } - }); + } } private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { @@ -75,8 +87,8 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces Optional> rows = dataExtractor.next(); if (rows.isPresent()) { for (DataFrameDataExtractor.Row row : rows.get()) { - String[] rowValues = row.getValues(); - if (rowValues != null) { + if (row.shouldSkip() == false) { + String[] rowValues = row.getValues(); System.arraycopy(rowValues, 0, record, 0, rowValues.length); process.writeRecord(record); } @@ -96,10 +108,10 @@ private void 
writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsPr process.writeRecord(headerRecord); } - private AnalyticsProcess createProcess(String jobId, DataFrameDataExtractor dataExtractor) { + private AnalyticsProcess createProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig) { // TODO We should rename the thread pool to reflect its more general use now, e.g. JOB_PROCESS_THREAD_POOL_NAME ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); - AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, createProcessConfig(dataExtractor), executorService); + AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, analyticsProcessConfig, executorService); if (process.isProcessAlive() == false) { throw ExceptionsHelper.serverError("Failed to start analytics process"); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java new file mode 100644 index 0000000000000..1f9ef71da8fb1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.analytics.process; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class AnalyticsResult implements ToXContentObject { + + public static final ParseField TYPE = new ParseField("analytics_result"); + public static final ParseField ID_HASH = new ParseField("id_hash"); + public static final ParseField RESULTS = new ParseField("results"); + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), + a -> new AnalyticsResult((String) a[0], (Map) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ID_HASH); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS); + } + + private final String idHash; + private final Map results; + + public AnalyticsResult(String idHash, Map results) { + this.idHash = Objects.requireNonNull(idHash); + this.results = Objects.requireNonNull(results); + } + + public String getIdHash() { + return idHash; + } + + public Map getResults() { + return results; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ID_HASH.getPreferredName(), idHash); + builder.field(RESULTS.getPreferredName(), results); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + + AnalyticsResult that = (AnalyticsResult) other; + return Objects.equals(idHash, that.idHash) && Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return 
Objects.hash(idHash, results); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java new file mode 100644 index 0000000000000..bdb1526b1b78a --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java @@ -0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.analytics.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class AnalyticsResultProcessor { + + private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); + + private final Client client; + private final DataFrameDataExtractor dataExtractor; + private List currentDataFrameRows; + private List currentResults; + private final CountDownLatch completionLatch = new CountDownLatch(1); + + public AnalyticsResultProcessor(Client client, DataFrameDataExtractor dataExtractor) { 
+ this.client = Objects.requireNonNull(client); + this.dataExtractor = Objects.requireNonNull(dataExtractor); + } + + public void awaitForCompletion() { + try { + if (completionLatch.await(30, TimeUnit.MINUTES) == false) { + LOGGER.warn("Timeout waiting for results processor to complete"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + LOGGER.info("Interrupted waiting for results processor to complete"); + } + } + + public void process(AnalyticsProcess process) { + + try { + Iterator iterator = process.readAnalyticsResults(); + while (iterator.hasNext()) { + try { + AnalyticsResult result = iterator.next(); + if (dataExtractor.hasNext() == false) { + return; + } + if (currentDataFrameRows == null) { + Optional> nextBatch = dataExtractor.next(); + if (nextBatch.isPresent() == false) { + return; + } + currentDataFrameRows = nextBatch.get(); + currentResults = new ArrayList<>(currentDataFrameRows.size()); + } + currentResults.add(result); + if (currentResults.size() == currentDataFrameRows.size()) { + joinCurrentResults(); + currentDataFrameRows = null; + } + } catch (Exception e) { + LOGGER.warn("Error processing analytics result", e); + } + + } + } catch (Exception e) { + LOGGER.error("Error parsing analytics output", e); + } finally { + completionLatch.countDown(); + process.consumeAndCloseOutputStream(); + } + } + + private void joinCurrentResults() { + BulkRequest bulkRequest = new BulkRequest(); + for (int i = 0; i < currentDataFrameRows.size(); i++) { + DataFrameDataExtractor.Row row = currentDataFrameRows.get(i); + if (row.shouldSkip()) { + continue; + } + AnalyticsResult result = currentResults.get(i); + SearchHit hit = row.getHit(); + Map source = new LinkedHashMap(hit.getSourceAsMap()); + source.putAll(result.getResults()); + IndexRequest indexRequest = new IndexRequest(hit.getIndex(), hit.getType(), hit.getId()); + indexRequest.source(source); + indexRequest.opType(DocWriteRequest.OpType.INDEX); + 
bulkRequest.add(indexRequest); + } + if (bulkRequest.numberOfActions() > 0) { + BulkResponse bulkResponse = client.execute(BulkAction.INSTANCE, bulkRequest).actionGet(); + if (bulkResponse.hasFailures()) { + LOGGER.error("Failures while writing data frame"); + // TODO Better error handling + } + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java index b8a387fd6e2b2..5f0f58e8b7b8a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java @@ -6,17 +6,21 @@ package org.elasticsearch.xpack.ml.analytics.process; import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; +import org.elasticsearch.xpack.ml.process.ProcessResultsParser; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.file.Path; +import java.util.Iterator; import java.util.List; public class NativeAnalyticsProcess extends AbstractNativeProcess implements AnalyticsProcess { private static final String NAME = "analytics"; + private final ProcessResultsParser resultsParser = new ProcessResultsParser<>(AnalyticsResult.PARSER); + protected NativeAnalyticsProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields, List filesToDelete, Runnable onProcessCrash) { @@ -37,4 +41,9 @@ public void persistState() { public void writeEndOfDataMessage() throws IOException { new AnalyticsControlMessageWriter(recordWriter(), numberOfFields()).writeEndOfData(); } + + @Override + public Iterator readAnalyticsResults() { + return resultsParser.parseResults(processOutStream()); + } } diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java index 69ed0d66c8606..4c7c4d553bcbc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcess.java @@ -14,13 +14,13 @@ import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.ml.job.persistence.StateStreamer; -import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutodetectResultsParser; import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.ForecastParams; import org.elasticsearch.xpack.ml.job.process.autodetect.writer.AutodetectControlMsgWriter; import org.elasticsearch.xpack.ml.job.results.AutodetectResult; import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; +import org.elasticsearch.xpack.ml.process.ProcessResultsParser; import java.io.IOException; import java.io.InputStream; @@ -38,11 +38,11 @@ class NativeAutodetectProcess extends AbstractNativeProcess implements Autodetec private static final String NAME = "autodetect"; - private final AutodetectResultsParser resultsParser; + private final ProcessResultsParser resultsParser; NativeAutodetectProcess(String jobId, InputStream logStream, OutputStream processInStream, InputStream processOutStream, OutputStream processRestoreStream, int numberOfFields, List filesToDelete, - AutodetectResultsParser resultsParser, Runnable onProcessCrash) { + ProcessResultsParser resultsParser, Runnable onProcessCrash) { 
super(jobId, logStream, processInStream, processOutStream, processRestoreStream, numberOfFields, filesToDelete, onProcessCrash); this.resultsParser = resultsParser; } @@ -117,17 +117,4 @@ public Iterator readAutodetectResults() { private AutodetectControlMsgWriter newMessageWriter() { return new AutodetectControlMsgWriter(recordWriter(), numberOfFields()); } - - @Override - public void consumeAndCloseOutputStream() { - try { - byte[] buff = new byte[512]; - while (processOutStream().read(buff) >= 0) { - // Do nothing - } - processOutStream().close(); - } catch (IOException e) { - throw new RuntimeException("Error closing result parser input stream", e); - } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java index 3185ebc6f1c7d..68aaa8a81c4b6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessFactory.java @@ -16,11 +16,12 @@ import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.process.NativeController; -import org.elasticsearch.xpack.ml.process.ProcessPipes; -import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutodetectResultsParser; import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutodetectStateProcessor; import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams; +import org.elasticsearch.xpack.ml.job.results.AutodetectResult; +import org.elasticsearch.xpack.ml.process.NativeController; +import org.elasticsearch.xpack.ml.process.ProcessPipes; +import 
org.elasticsearch.xpack.ml.process.ProcessResultsParser; import org.elasticsearch.xpack.ml.utils.NamedPipeHelper; import java.io.IOException; @@ -68,7 +69,7 @@ public AutodetectProcess createAutodetectProcess(Job job, int numberOfFields = job.allInputFields().size() + (includeTokensField ? 1 : 0) + 1; AutodetectStateProcessor stateProcessor = new AutodetectStateProcessor(client, job.getId()); - AutodetectResultsParser resultsParser = new AutodetectResultsParser(); + ProcessResultsParser resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER); NativeAutodetectProcess autodetect = new NativeAutodetectProcess( job.getId(), processPipes.getLogStream().get(), processPipes.getProcessInStream().get(), processPipes.getProcessOutStream().get(), processPipes.getRestoreStream().orElse(null), numberOfFields, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java index b84bfdd38e19a..8325d09b47050 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java @@ -262,4 +262,16 @@ protected LengthEncodedWriter recordWriter() { protected boolean isProcessKilled() { return processKilled; } + + public void consumeAndCloseOutputStream() { + try { + byte[] buff = new byte[512]; + while (processOutStream().read(buff) >= 0) { + // Do nothing + } + processOutStream().close(); + } catch (IOException e) { + throw new RuntimeException("Error closing result parser input stream", e); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParser.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java similarity index 72% rename from 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParser.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java index 2ec37a0f86e5d..609c45659dd6c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParser.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/ProcessResultsParser.java @@ -3,31 +3,41 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.job.process.autodetect.output; +package org.elasticsearch.xpack.ml.process; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.xpack.ml.job.results.AutodetectResult; import java.io.IOException; import java.io.InputStream; import java.util.Iterator; +import java.util.Objects; /** - * Parses the JSON output of the autodetect program. + * Parses the JSON output of a process. *

- * Expects an array of buckets so the first element will always be the + * Expects an array of objects so the first element will always be the * start array symbol and the data must be terminated with the end array symbol. */ -public class AutodetectResultsParser { - public Iterator parseResults(InputStream in) throws ElasticsearchParseException { +public class ProcessResultsParser { + + private static final Logger logger = LogManager.getLogger(ProcessResultsParser.class); + + private final ConstructingObjectParser resultParser; + + public ProcessResultsParser(ConstructingObjectParser resultParser) { + this.resultParser = Objects.requireNonNull(resultParser); + } + + public Iterator parseResults(InputStream in) throws ElasticsearchParseException { try { XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, in); @@ -36,21 +46,19 @@ public Iterator parseResults(InputStream in) throws Elasticsea if (token != XContentParser.Token.START_ARRAY) { throw new ElasticsearchParseException("unexpected token [" + token + "]"); } - return new AutodetectResultIterator(in, parser); + return new ResultIterator(in, parser); } catch (IOException e) { throw new ElasticsearchParseException(e.getMessage(), e); } } - private static class AutodetectResultIterator implements Iterator { - - private static final Logger logger = LogManager.getLogger(AutodetectResultIterator.class); + private class ResultIterator implements Iterator { private final InputStream in; private final XContentParser parser; private XContentParser.Token token; - private AutodetectResultIterator(InputStream in, XContentParser parser) { + private ResultIterator(InputStream in, XContentParser parser) { this.in = in; this.parser = parser; token = parser.currentToken(); @@ -74,8 +82,8 @@ public boolean hasNext() { } @Override - public AutodetectResult next() { - return AutodetectResult.PARSER.apply(parser, null); + public T next() 
{ + return resultParser.apply(parser, null); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java index 1adf91fa884c8..3845b9df26b4c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java @@ -35,7 +35,7 @@ public void testWriteEndOfData() throws IOException { InOrder inOrder = inOrder(lengthEncodedWriter); inOrder.verify(lengthEncodedWriter).writeNumFields(4); inOrder.verify(lengthEncodedWriter, times(3)).writeField(""); - inOrder.verify(lengthEncodedWriter).writeField("r"); + inOrder.verify(lengthEncodedWriter).writeField("$"); StringBuilder spaces = new StringBuilder(); IntStream.rangeClosed(1, 8192).forEach(i -> spaces.append(' ')); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java new file mode 100644 index 0000000000000..3d71b4d430e2b --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.analytics.process; + +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AnalyticsResultProcessorTests extends ESTestCase { + + private Client client; + private AnalyticsProcess process; + private DataFrameDataExtractor dataExtractor; + private ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + + @Before + public void setUpMocks() { + client = mock(Client.class); + process = mock(AnalyticsProcess.class); + dataExtractor = mock(DataFrameDataExtractor.class); + } + + public void testProcess_GivenNoResults() { + givenProcessResults(Collections.emptyList()); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verifyNoMoreInteractions(client); + } + + public void testProcess_GivenSingleRowAndResult() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = 
"{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row = newRow(newHit("1", dataDoc), dataValues); + givenSingleDataFrameBatch(Arrays.asList(row)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + AnalyticsResult result = new AnalyticsResult("some_hash", resultFields); + givenProcessResults(Arrays.asList(result)); + + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + List capturedBulkRequests = bulkRequestCaptor.getAllValues(); + assertThat(capturedBulkRequests.size(), equalTo(1)); + BulkRequest capturedBulkRequest = capturedBulkRequests.get(0); + assertThat(capturedBulkRequest.numberOfActions(), equalTo(1)); + IndexRequest indexRequest = (IndexRequest) capturedBulkRequest.requests().get(0); + Map indexedDocSource = indexRequest.sourceAsMap(); + assertThat(indexedDocSource.size(), equalTo(4)); + assertThat(indexedDocSource.get("f_1"), equalTo("foo")); + assertThat(indexedDocSource.get("f_2"), equalTo(42.0)); + assertThat(indexedDocSource.get("a"), equalTo("1")); + assertThat(indexedDocSource.get("b"), equalTo("2")); + } + + private void givenProcessResults(List results) { + when(process.readAnalyticsResults()).thenReturn(results.iterator()); + } + + private void givenSingleDataFrameBatch(List batch) throws IOException { + when(dataExtractor.hasNext()).thenReturn(true).thenReturn(true).thenReturn(false); + when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty()); + } + + private static SearchHit newHit(String id, String json) { + SearchHit hit = new SearchHit(42, id, new Text("doc"), Collections.emptyMap()); + hit.sourceRef(new BytesArray(json)); + return hit; + } + + private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values) { + DataFrameDataExtractor.Row row = mock(DataFrameDataExtractor.Row.class); + 
when(row.getHit()).thenReturn(hit); + when(row.getValues()).thenReturn(values); + return row; + } + + private void givenClientHasNoFailures() { + ActionFuture responseFuture = mock(ActionFuture.class); + when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); + when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); + } + + private AnalyticsResultProcessor createResultProcessor() { + return new AnalyticsResultProcessor(client, dataExtractor); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java new file mode 100644 index 0000000000000..6250a96cd3284 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.analytics.process; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class AnalyticsResultTests extends AbstractXContentTestCase { + + @Override + protected AnalyticsResult createTestInstance() { + String idHash = randomAlphaOfLength(20); + Map results = new HashMap<>(); + int resultsSize = randomIntBetween(1, 10); + for (int i = 0; i < resultsSize; i++) { + String resultField = randomAlphaOfLength(20); + Object resultValue = randomBoolean() ? 
randomAlphaOfLength(20) : randomDouble(); + results.put(resultField, resultValue); + } + return new AnalyticsResult(idHash, results); + } + + @Override + protected AnalyticsResult doParseInstance(XContentParser parser) throws IOException { + return AnalyticsResult.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java index 8542061c761a2..0b7a0273345d3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/NativeAutodetectProcessTests.java @@ -7,12 +7,13 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.job.config.ModelPlotConfig; -import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutodetectResultsParser; import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutodetectStateProcessor; import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams; import org.elasticsearch.xpack.ml.job.process.autodetect.params.TimeRange; import org.elasticsearch.xpack.ml.job.process.autodetect.writer.AutodetectControlMsgWriter; +import org.elasticsearch.xpack.ml.job.results.AutodetectResult; +import org.elasticsearch.xpack.ml.process.ProcessResultsParser; import org.junit.Assert; import org.junit.Before; @@ -58,7 +59,7 @@ public void testProcessStartTime() throws Exception { try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, mock(OutputStream.class), outputStream, mock(OutputStream.class), NUMBER_FIELDS, null, - new AutodetectResultsParser(), 
mock(Runnable.class))) { + new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Runnable.class))) { process.start(executorService, mock(AutodetectStateProcessor.class), mock(InputStream.class)); ZonedDateTime startTime = process.getProcessStartTime(); @@ -80,7 +81,7 @@ public void testWriteRecord() throws IOException { ByteArrayOutputStream bos = new ByteArrayOutputStream(1024); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(), - new AutodetectResultsParser(), mock(Runnable.class))) { + new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Runnable.class))) { process.start(executorService, mock(AutodetectStateProcessor.class), mock(InputStream.class)); process.writeRecord(record); @@ -114,7 +115,7 @@ public void testFlush() throws IOException { ByteArrayOutputStream bos = new ByteArrayOutputStream(AutodetectControlMsgWriter.FLUSH_SPACES_LENGTH + 1024); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(), - new AutodetectResultsParser(), mock(Runnable.class))) { + new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Runnable.class))) { process.start(executorService, mock(AutodetectStateProcessor.class), mock(InputStream.class)); FlushJobParams params = FlushJobParams.builder().build(); @@ -147,7 +148,7 @@ public void testConsumeAndCloseOutputStream() throws IOException { try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, processInStream, processOutStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(), - new AutodetectResultsParser(), mock(Runnable.class))) { + new ProcessResultsParser(AutodetectResult.PARSER), mock(Runnable.class))) { process.consumeAndCloseOutputStream(); assertThat(processOutStream.available(), equalTo(0)); @@ -162,7 +163,7 @@ private void 
testWriteMessage(CheckedConsumer writeFunc ByteArrayOutputStream bos = new ByteArrayOutputStream(1024); try (NativeAutodetectProcess process = new NativeAutodetectProcess("foo", logStream, bos, outputStream, mock(OutputStream.class), NUMBER_FIELDS, Collections.emptyList(), - new AutodetectResultsParser(), mock(Runnable.class))) { + new ProcessResultsParser<>(AutodetectResult.PARSER), mock(Runnable.class))) { process.start(executorService, mock(AutodetectStateProcessor.class), mock(InputStream.class)); writeFunction.accept(process); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParserTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParserTests.java deleted file mode 100644 index 1118453154ed8..0000000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutodetectResultsParserTests.java +++ /dev/null @@ -1,422 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.ml.job.process.autodetect.output; - -import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.common.xcontent.XContentParseException; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.Quantiles; -import org.elasticsearch.xpack.core.ml.job.results.Bucket; -import org.elasticsearch.xpack.core.ml.job.results.BucketInfluencer; -import org.elasticsearch.xpack.ml.job.results.AutodetectResult; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.stream.Collectors; - -/** - * Tests for parsing the JSON output of autodetect - */ -public class AutodetectResultsParserTests extends ESTestCase { - private static final double EPSILON = 0.000001; - - private static final String METRIC_OUTPUT_SAMPLE = "[{\"bucket\": {\"job_id\":\"foo\",\"timestamp\":1359450000000," - + "\"bucket_span\":22, \"records\":[]," - + "\"anomaly_score\":0,\"event_count\":806,\"bucket_influencers\":[" - + "{\"timestamp\":1359450000000,\"bucket_span\":22,\"job_id\":\"foo\",\"anomaly_score\":0," - + "\"probability\":0.0, \"influencer_field_name\":\"bucket_time\"," - + "\"initial_anomaly_score\":0.0}]}},{\"quantiles\": {\"job_id\":\"foo\", \"quantile_state\":\"[normalizer 1.1, normalizer 2" + - ".1]\",\"timestamp\":1359450000000}}" - + ",{\"bucket\": {\"job_id\":\"foo\",\"timestamp\":1359453600000,\"bucket_span\":22,\"records\":" - + "[{\"timestamp\":1359453600000,\"bucket_span\":22,\"job_id\":\"foo\",\"probability\":0.0637541," - + "\"by_field_name\":\"airline\",\"by_field_value\":\"JZA\", \"typical\":[1020.08],\"actual\":[1042.14]," - + "\"field_name\":\"responsetime\",\"function\":\"max\",\"partition_field_name\":\"\",\"partition_field_value\":\"\"}," - + 
"{\"timestamp\":1359453600000,\"bucket_span\":22,\"job_id\":\"foo\",\"probability\":0.00748292," - + "\"by_field_name\":\"airline\",\"by_field_value\":\"AMX\", " - + "\"typical\":[20.2137],\"actual\":[22.8855],\"field_name\":\"responsetime\",\"function\":\"max\"," - + "\"partition_field_name\":\"\",\"partition_field_value\":\"\"},{\"timestamp\":1359453600000,\"bucket_span\":22," - + "\"job_id\":\"foo\",\"probability\":0.023494,\"by_field_name\":\"airline\"," - + "\"by_field_value\":\"DAL\", \"typical\":[382.177],\"actual\":[358.934],\"field_name\":\"responsetime\",\"function\":\"min\"," - + "\"partition_field_name\":\"\", \"partition_field_value\":\"\"},{\"timestamp\":1359453600000,\"bucket_span\":22," - + "\"job_id\":\"foo\"," - + "\"probability\":0.0473552,\"by_field_name\":\"airline\",\"by_field_value\":\"SWA\", \"typical\":[152.148]," - + "\"actual\":[96.6425],\"field_name\":\"responsetime\",\"function\":\"min\",\"partition_field_name\":\"\"," - + "\"partition_field_value\":\"\"}]," - + "\"initial_anomaly_score\":0.0140005, \"anomaly_score\":20.22688," - + "\"event_count\":820,\"bucket_influencers\":[{\"timestamp\":1359453600000,\"bucket_span\":22," - + "\"job_id\":\"foo\", \"raw_anomaly_score\":0.0140005, \"probability\":0.01,\"influencer_field_name\":\"bucket_time\"," - + "\"initial_anomaly_score\":20.22688,\"anomaly_score\":20.22688} ,{\"timestamp\":1359453600000,\"bucket_span\":22," - + "\"job_id\":\"foo\",\"raw_anomaly_score\":0.005, \"probability\":0.03," - + "\"influencer_field_name\":\"foo\",\"initial_anomaly_score\":10.5,\"anomaly_score\":10.5}]}},{\"quantiles\": " - + "{\"job_id\":\"foo\",\"timestamp\":1359453600000," - + "\"quantile_state\":\"[normalizer 1.2, normalizer 2.2]\"}} ,{\"flush\": {\"id\":\"testing1\"}} ," - + "{\"quantiles\": {\"job_id\":\"foo\",\"timestamp\":1359453600000,\"quantile_state\":\"[normalizer 1.3, normalizer 2.3]\"}} ]"; - - private static final String POPULATION_OUTPUT_SAMPLE = 
"[{\"timestamp\":1379590200,\"records\":[{\"probability\":1.38951e-08," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"mail.google.com\"," - + "\"function\":\"max\"," - + "\"causes\":[{\"probability\":1.38951e-08,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"mail.google.com\",\"function\":\"max\",\"typical\":[101534],\"actual\":[9.19027e+07]}]," - + "\"record_score\":100,\"anomaly_score\":44.7324},{\"probability\":3.86587e-07,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"armmf.adobe.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":3.86587e-07,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"armmf.adobe.com\",\"function\":\"max\",\"typical\":[101534],\"actual\":[3.20093e+07]}]," - + "\"record_score\":89.5834,\"anomaly_score\":44.7324},{\"probability\":0.00500083,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"0.docs.google.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.00500083,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"0.docs.google.com\",\"function\":\"max\",\"typical\":[101534],\"actual\":[6.61812e+06]}]," - + "\"record_score\":1.19856,\"anomaly_score\":44.7324},{\"probability\":0.0152333,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"emea.salesforce.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0152333,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"emea.salesforce.com\",\"function\":\"max\",\"typical\":[101534],\"actual\":[5.36373e+06]}]," - + "\"record_score\":0.303996,\"anomaly_score\":44.7324}],\"raw_anomaly_score\":1.30397,\"anomaly_score\":44.7324," - + "\"event_count\":1235}" + ",{\"flush\":\"testing2\"}" - + 
",{\"timestamp\":1379590800,\"records\":[{\"probability\":1.9008e-08,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"mail.google.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":1.9008e-08,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"mail.google.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[1.1498e+08]}]," - + "\"record_score\":93.6213,\"anomaly_score\":1.19192},{\"probability\":1.01013e-06,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"armmf.adobe.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":1.01013e-06,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"armmf.adobe.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[3.25808e+07]}]," - + "\"record_score\":86.5825,\"anomaly_score\":1.19192},{\"probability\":0.000386185,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"0.docs.google.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.000386185,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"0.docs.google.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[3.22855e+06]}]," - + "\"record_score\":17.1179,\"anomaly_score\":1.19192},{\"probability\":0.00208033,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"docs.google.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.00208033,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"docs.google.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[1.43328e+06]}]," - + "\"record_score\":3.0692,\"anomaly_score\":1.19192},{\"probability\":0.00312988,\"field_name\":\"sum_cs_bytes_\"," - + 
"\"over_field_name\":\"cs_host\",\"over_field_value\":\"booking2.airasia.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.00312988,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"booking2.airasia.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[1.15764e+06]}]," - + "\"record_score\":1.99532,\"anomaly_score\":1.19192},{\"probability\":0.00379229,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.facebook.com\",\"function\":\"max\",\"causes\":[" - + "{\"probability\":0.00379229,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.facebook.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[1.0443e+06]}]," - + "\"record_score\":1.62352,\"anomaly_score\":1.19192},{\"probability\":0.00623576,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.airasia.com\",\"function\":\"max\",\"causes\":[" - + "{\"probability\":0.00623576,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.airasia.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[792699]}]," - + "\"record_score\":0.935134,\"anomaly_score\":1.19192},{\"probability\":0.00665308,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.google.com\",\"function\":\"max\",\"causes\":[" - + "{\"probability\":0.00665308,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.google.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[763985]}]," - + "\"record_score\":0.868119,\"anomaly_score\":1.19192},{\"probability\":0.00709315,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"0.drive.google.com\",\"function\":\"max\",\"causes\":[{" - + 
"\"probability\":0.00709315,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"0.drive.google.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[736442]}]," - + "\"record_score\":0.805994,\"anomaly_score\":1.19192},{\"probability\":0.00755789,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"resources2.news.com.au\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.00755789,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"resources2.news.com.au\",\"function\":\"max\",\"typical\":[31356],\"actual\":[709962]}]," - + "\"record_score\":0.748239,\"anomaly_score\":1.19192},{\"probability\":0.00834974,\"field_name\":" - + "\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.calypso.net.au\",\"function\":\"max\"," - + "\"causes\":[{\"probability\":0.00834974,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.calypso.net.au\",\"function\":\"max\",\"typical\":[31356],\"actual\":[669968]}]," - + "\"record_score\":0.664644,\"anomaly_score\":1.19192},{\"probability\":0.0107711,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"ad.yieldmanager.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0107711,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"ad.yieldmanager.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[576067]}]," - + "\"record_score\":0.485277,\"anomaly_score\":1.19192},{\"probability\":0.0123367,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.google-analytics.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0123367,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + 
"\"over_field_value\":\"www.google-analytics.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[530594]}]," - + "\"record_score\":0.406783,\"anomaly_score\":1.19192},{\"probability\":0.0125647,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"bs.serving-sys.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0125647,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"bs.serving-sys.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[524690]}]," - + "\"record_score\":0.396986,\"anomaly_score\":1.19192},{\"probability\":0.0141652,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.google.com.au\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0141652,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.google.com.au\",\"function\":\"max\",\"typical\":[31356],\"actual\":[487328]}]," - + "\"record_score\":0.337075,\"anomaly_score\":1.19192},{\"probability\":0.0141742,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"resources1.news.com.au\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0141742,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"resources1.news.com.au\",\"function\":\"max\",\"typical\":[31356],\"actual\":[487136]}]," - + "\"record_score\":0.336776,\"anomaly_score\":1.19192},{\"probability\":0.0145263,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"b.mail.google.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0145263,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"b.mail.google.com\",\"function\":\"max\",\"typical\":[31356],\"actual\":[479766]}]," - + 
"\"record_score\":0.325385,\"anomaly_score\":1.19192},{\"probability\":0.0151447,\"field_name\":\"sum_cs_bytes_\"," - + "\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.rei.com\",\"function\":\"max\",\"causes\":[{" - + "\"probability\":0.0151447,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.rei" + - ".com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[467450]}],\"record_score\":0.306657,\"anomaly_score\":1" + - ".19192}," - + "{\"probability\":0.0164073,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"s3.amazonaws.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0164073," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"s3.amazonaws.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[444511]}],\"record_score\":0.272805,\"anomaly_score\":1" + - ".19192}," - + "{\"probability\":0.0201927,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"0-p-06-ash2.channel.facebook.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0201927," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"0-p-06-ash2.channel.facebook.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[389243]}],\"record_score\":0.196685,\"anomaly_score\":1" + - ".19192}," - + "{\"probability\":0.0218721,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"booking.airasia.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0218721," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"booking.airasia.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[369509]}],\"record_score\":0.171353," - + "\"anomaly_score\":1.19192},{\"probability\":0.0242411,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + 
"\"over_field_value\":\"www.yammer.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0242411," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.yammer.com\"," + - "\"function\":\"max\"," - + "\"typical\":[31356],\"actual\":[345295]}],\"record_score\":0.141585,\"anomaly_score\":1.19192}," - + "{\"probability\":0.0258232,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"safebrowsing-cache.google.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0258232," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"safebrowsing-cache.google.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[331051]}],\"record_score\":0.124748,\"anomaly_score\":1" + - ".19192}," - + "{\"probability\":0.0259695,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"fbcdn-profile-a.akamaihd.net\",\"function\":\"max\",\"causes\":[{\"probability\":0.0259695," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"fbcdn-profile-a.akamaihd.net\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[329801]}],\"record_score\":0.123294,\"anomaly_score\":1" + - ".19192}," - + "{\"probability\":0.0268874,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.oag.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0268874," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.oag.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[322200]}],\"record_score\":0.114537," - + "\"anomaly_score\":1.19192},{\"probability\":0.0279146,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"booking.qatarairways.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0279146," - + 
"\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"booking.qatarairways.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[314153]}],\"record_score\":0.105419,\"anomaly_score\":1" + - ".19192}," - + "{\"probability\":0.0309351,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"resources3.news.com.au\",\"function\":\"max\",\"causes\":[{\"probability\":0.0309351," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"resources3.news.com.au\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[292918]}],\"record_score\":0.0821156,\"anomaly_score\":1" + - ".19192}" - + ",{\"probability\":0.0335204,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"resources0.news.com.au\",\"function\":\"max\",\"causes\":[{\"probability\":0.0335204," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"resources0.news.com.au\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[277136]}],\"record_score\":0.0655063,\"anomaly_score\":1" + - ".19192}" - + ",{\"probability\":0.0354927,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.southwest.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0354927," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.southwest.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[266310]}],\"record_score\":0.0544615," - + "\"anomaly_score\":1.19192},{\"probability\":0.0392043,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"syndication.twimg.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0392043," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"syndication.twimg.com\"," - + 
"\"function\":\"max\",\"typical\":[31356],\"actual\":[248276]}],\"record_score\":0.0366913,\"anomaly_score\":1" + - ".19192}" - + ",{\"probability\":0.0400853,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"" - + ",\"over_field_value\":\"mts0.google.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0400853," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"mts0.google.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[244381]}],\"record_score\":0.0329562," - + "\"anomaly_score\":1.19192},{\"probability\":0.0407335,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"www.onthegotours.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0407335," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"www.onthegotours.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[241600]}],\"record_score\":0.0303116," - + "\"anomaly_score\":1.19192},{\"probability\":0.0470889,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"chatenabled.mail.google.com\",\"function\":\"max\",\"causes\":[{\"probability\":0.0470889," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"chatenabled.mail.google.com\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[217573]}],\"record_score\":0.00823738," - + "\"anomaly_score\":1.19192},{\"probability\":0.0491243,\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\"," - + "\"over_field_value\":\"googleads.g.doubleclick.net\",\"function\":\"max\",\"causes\":[{\"probability\":0.0491243," - + "\"field_name\":\"sum_cs_bytes_\",\"over_field_name\":\"cs_host\",\"over_field_value\":\"googleads.g.doubleclick.net\"," - + "\"function\":\"max\",\"typical\":[31356],\"actual\":[210926]}],\"record_score\":0.00237509," - + 
"\"anomaly_score\":1.19192}],\"raw_anomaly_score\":1.26918,\"anomaly_score\":1.19192," - + "\"event_count\":1159}" + "]"; - - public void testParser() throws IOException { - try (InputStream inputStream = new ByteArrayInputStream(METRIC_OUTPUT_SAMPLE.getBytes(StandardCharsets.UTF_8))) { - AutodetectResultsParser parser = new AutodetectResultsParser(); - List results = new ArrayList<>(); - parser.parseResults(inputStream).forEachRemaining(results::add); - List buckets = results.stream().map(AutodetectResult::getBucket) - .filter(b -> b != null) - .collect(Collectors.toList()); - - assertEquals(2, buckets.size()); - assertEquals(new Date(1359450000000L), buckets.get(0).getTimestamp()); - - assertEquals(buckets.get(0).getEventCount(), 806); - - List bucketInfluencers = buckets.get(0).getBucketInfluencers(); - assertEquals(1, bucketInfluencers.size()); - assertEquals(0.0, bucketInfluencers.get(0).getRawAnomalyScore(), EPSILON); - assertEquals(0.0, bucketInfluencers.get(0).getAnomalyScore(), EPSILON); - assertEquals(0.0, bucketInfluencers.get(0).getProbability(), EPSILON); - assertEquals("bucket_time", bucketInfluencers.get(0).getInfluencerFieldName()); - - assertEquals(new Date(1359453600000L), buckets.get(1).getTimestamp()); - - assertEquals(buckets.get(1).getEventCount(), 820); - bucketInfluencers = buckets.get(1).getBucketInfluencers(); - assertEquals(2, bucketInfluencers.size()); - assertEquals(0.0140005, bucketInfluencers.get(0).getRawAnomalyScore(), EPSILON); - assertEquals(20.22688, bucketInfluencers.get(0).getAnomalyScore(), EPSILON); - assertEquals(0.01, bucketInfluencers.get(0).getProbability(), EPSILON); - assertEquals("bucket_time", bucketInfluencers.get(0).getInfluencerFieldName()); - assertEquals(0.005, bucketInfluencers.get(1).getRawAnomalyScore(), EPSILON); - assertEquals(10.5, bucketInfluencers.get(1).getAnomalyScore(), EPSILON); - assertEquals(0.03, bucketInfluencers.get(1).getProbability(), EPSILON); - assertEquals("foo", 
bucketInfluencers.get(1).getInfluencerFieldName()); - - Bucket secondBucket = buckets.get(1); - - assertEquals(0.0637541, secondBucket.getRecords().get(0).getProbability(), EPSILON); - assertEquals("airline", secondBucket.getRecords().get(0).getByFieldName()); - assertEquals("JZA", secondBucket.getRecords().get(0).getByFieldValue()); - assertEquals(1020.08, secondBucket.getRecords().get(0).getTypical().get(0), EPSILON); - assertEquals(1042.14, secondBucket.getRecords().get(0).getActual().get(0), EPSILON); - assertEquals("responsetime", secondBucket.getRecords().get(0).getFieldName()); - assertEquals("max", secondBucket.getRecords().get(0).getFunction()); - assertEquals("", secondBucket.getRecords().get(0).getPartitionFieldName()); - assertEquals("", secondBucket.getRecords().get(0).getPartitionFieldValue()); - - assertEquals(0.00748292, secondBucket.getRecords().get(1).getProbability(), EPSILON); - assertEquals("airline", secondBucket.getRecords().get(1).getByFieldName()); - assertEquals("AMX", secondBucket.getRecords().get(1).getByFieldValue()); - assertEquals(20.2137, secondBucket.getRecords().get(1).getTypical().get(0), EPSILON); - assertEquals(22.8855, secondBucket.getRecords().get(1).getActual().get(0), EPSILON); - assertEquals("responsetime", secondBucket.getRecords().get(1).getFieldName()); - assertEquals("max", secondBucket.getRecords().get(1).getFunction()); - assertEquals("", secondBucket.getRecords().get(1).getPartitionFieldName()); - assertEquals("", secondBucket.getRecords().get(1).getPartitionFieldValue()); - - assertEquals(0.023494, secondBucket.getRecords().get(2).getProbability(), EPSILON); - assertEquals("airline", secondBucket.getRecords().get(2).getByFieldName()); - assertEquals("DAL", secondBucket.getRecords().get(2).getByFieldValue()); - assertEquals(382.177, secondBucket.getRecords().get(2).getTypical().get(0), EPSILON); - assertEquals(358.934, secondBucket.getRecords().get(2).getActual().get(0), EPSILON); - assertEquals("responsetime", 
secondBucket.getRecords().get(2).getFieldName()); - assertEquals("min", secondBucket.getRecords().get(2).getFunction()); - assertEquals("", secondBucket.getRecords().get(2).getPartitionFieldName()); - assertEquals("", secondBucket.getRecords().get(2).getPartitionFieldValue()); - - assertEquals(0.0473552, secondBucket.getRecords().get(3).getProbability(), EPSILON); - assertEquals("airline", secondBucket.getRecords().get(3).getByFieldName()); - assertEquals("SWA", secondBucket.getRecords().get(3).getByFieldValue()); - assertEquals(152.148, secondBucket.getRecords().get(3).getTypical().get(0), EPSILON); - assertEquals(96.6425, secondBucket.getRecords().get(3).getActual().get(0), EPSILON); - assertEquals("responsetime", secondBucket.getRecords().get(3).getFieldName()); - assertEquals("min", secondBucket.getRecords().get(3).getFunction()); - assertEquals("", secondBucket.getRecords().get(3).getPartitionFieldName()); - assertEquals("", secondBucket.getRecords().get(3).getPartitionFieldValue()); - - List quantiles = results.stream().map(AutodetectResult::getQuantiles) - .filter(q -> q != null) - .collect(Collectors.toList()); - assertEquals(3, quantiles.size()); - assertEquals("foo", quantiles.get(0).getJobId()); - assertEquals(new Date(1359450000000L), quantiles.get(0).getTimestamp()); - assertEquals("[normalizer 1.1, normalizer 2.1]", quantiles.get(0).getQuantileState()); - assertEquals("foo", quantiles.get(1).getJobId()); - assertEquals(new Date(1359453600000L), quantiles.get(1).getTimestamp()); - assertEquals("[normalizer 1.2, normalizer 2.2]", quantiles.get(1).getQuantileState()); - assertEquals("foo", quantiles.get(2).getJobId()); - assertEquals(new Date(1359453600000L), quantiles.get(2).getTimestamp()); - assertEquals("[normalizer 1.3, normalizer 2.3]", quantiles.get(2).getQuantileState()); - } - } - - @AwaitsFix(bugUrl = "rewrite this test so it doesn't use ~200 lines of json") - public void testPopulationParser() throws IOException { - try (InputStream 
inputStream = new ByteArrayInputStream(POPULATION_OUTPUT_SAMPLE.getBytes(StandardCharsets.UTF_8))) { - AutodetectResultsParser parser = new AutodetectResultsParser(); - List results = new ArrayList<>(); - parser.parseResults(inputStream).forEachRemaining(results::add); - List buckets = results.stream().map(AutodetectResult::getBucket) - .filter(b -> b != null) - .collect(Collectors.toList()); - - assertEquals(2, buckets.size()); - assertEquals(new Date(1379590200000L), buckets.get(0).getTimestamp()); - assertEquals(buckets.get(0).getEventCount(), 1235); - - Bucket firstBucket = buckets.get(0); - assertEquals(1.38951e-08, firstBucket.getRecords().get(0).getProbability(), EPSILON); - assertEquals("sum_cs_bytes_", firstBucket.getRecords().get(0).getFieldName()); - assertEquals("max", firstBucket.getRecords().get(0).getFunction()); - assertEquals("cs_host", firstBucket.getRecords().get(0).getOverFieldName()); - assertEquals("mail.google.com", firstBucket.getRecords().get(0).getOverFieldValue()); - assertNotNull(firstBucket.getRecords().get(0).getCauses()); - - assertEquals(new Date(1379590800000L), buckets.get(1).getTimestamp()); - assertEquals(buckets.get(1).getEventCount(), 1159); - } - } - - public void testParse_GivenEmptyArray() throws ElasticsearchParseException, IOException { - String json = "[]"; - try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { - AutodetectResultsParser parser = new AutodetectResultsParser(); - assertFalse(parser.parseResults(inputStream).hasNext()); - } - } - - public void testParse_GivenModelSizeStats() throws ElasticsearchParseException, IOException { - String json = "[{\"model_size_stats\": {\"job_id\": \"foo\", \"model_bytes\":300}}]"; - try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { - - AutodetectResultsParser parser = new AutodetectResultsParser(); - List results = new ArrayList<>(); - 
parser.parseResults(inputStream).forEachRemaining(results::add); - - assertEquals(1, results.size()); - assertEquals(300, results.get(0).getModelSizeStats().getModelBytes()); - } - } - - public void testParse_GivenCategoryDefinition() throws IOException { - String json = "[{\"category_definition\": {\"job_id\":\"foo\", \"category_id\":18}}]"; - try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { - AutodetectResultsParser parser = new AutodetectResultsParser(); - List results = new ArrayList<>(); - parser.parseResults(inputStream).forEachRemaining(results::add); - - - assertEquals(1, results.size()); - assertEquals(18, results.get(0).getCategoryDefinition().getCategoryId()); - } - } - - public void testParse_GivenUnknownObject() throws ElasticsearchParseException, IOException { - String json = "[{\"unknown\":{\"id\": 18}}]"; - try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { - AutodetectResultsParser parser = new AutodetectResultsParser(); - XContentParseException e = expectThrows(XContentParseException.class, - () -> parser.parseResults(inputStream).forEachRemaining(a -> { - })); - assertEquals("[1:3] [autodetect_result] unknown field [unknown], parser not found", e.getMessage()); - } - } - - public void testParse_GivenArrayContainsAnotherArray() throws ElasticsearchParseException, IOException { - String json = "[[]]"; - try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { - AutodetectResultsParser parser = new AutodetectResultsParser(); - ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, - () -> parser.parseResults(inputStream).forEachRemaining(a -> { - })); - assertEquals("unexpected token [START_ARRAY]", e.getMessage()); - } - } - - /** - * Ensure that we do not accept NaN values - */ - public void testParsingExceptionNaN() { - String json = "[{\"bucket\": 
{\"job_id\":\"foo\",\"timestamp\":1359453600000,\"bucket_span\":10,\"records\":" - + "[{\"timestamp\":1359453600000,\"bucket_span\":10,\"job_id\":\"foo\",\"probability\":NaN," - + "\"by_field_name\":\"airline\",\"by_field_value\":\"JZA\", \"typical\":[1020.08],\"actual\":[0]," - + "\"field_name\":\"responsetime\",\"function\":\"max\",\"partition_field_name\":\"\",\"partition_field_value\":\"\"}]}}]"; - InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8)); - AutodetectResultsParser parser = new AutodetectResultsParser(); - - expectThrows(XContentParseException.class, - () -> parser.parseResults(inputStream).forEachRemaining(a -> {})); - } -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java new file mode 100644 index 0000000000000..32ab15a27019f --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/ProcessResultsParserTests.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.process; + +import com.google.common.base.Charsets; +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.test.ESTestCase; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +import static org.hamcrest.Matchers.contains; + +public class ProcessResultsParserTests extends ESTestCase { + + public void testParse_GivenEmptyArray() throws IOException { + String json = "[]"; + try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { + ProcessResultsParser parser = new ProcessResultsParser<>(TestResult.PARSER); + assertFalse(parser.parseResults(inputStream).hasNext()); + } + } + + public void testParse_GivenUnknownObject() throws IOException { + String json = "[{\"unknown\":{\"id\": 18}}]"; + try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { + ProcessResultsParser parser = new ProcessResultsParser<>(TestResult.PARSER); + XContentParseException e = expectThrows(XContentParseException.class, + () -> parser.parseResults(inputStream).forEachRemaining(a -> { + })); + assertEquals("[1:3] [test_result] unknown field [unknown], parser not found", e.getMessage()); + } + } + + public void testParse_GivenArrayContainsAnotherArray() throws IOException { + String json = "[[]]"; + try (InputStream inputStream = new ByteArrayInputStream(json.getBytes(StandardCharsets.UTF_8))) { + ProcessResultsParser parser = new ProcessResultsParser<>(TestResult.PARSER); + ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, + () -> 
parser.parseResults(inputStream).forEachRemaining(a -> { + })); + assertEquals("unexpected token [START_ARRAY]", e.getMessage()); + } + } + + public void testParseResults() throws IOException { + String input = "[{\"field_1\": \"a\", \"field_2\": 1.0}, {\"field_1\": \"b\", \"field_2\": 2.0}," + + " {\"field_1\": \"c\", \"field_2\": 3.0}]"; + try (InputStream inputStream = new ByteArrayInputStream(input.getBytes(Charsets.UTF_8))) { + + ProcessResultsParser parser = new ProcessResultsParser<>(TestResult.PARSER); + Iterator testResultIterator = parser.parseResults(inputStream); + + List parsedResults = new ArrayList<>(); + while (testResultIterator.hasNext()) { + parsedResults.add(testResultIterator.next()); + } + + assertThat(parsedResults, contains(new TestResult("a", 1.0), new TestResult("b", 2.0), new TestResult("c", 3.0))); + } + } + + private static class TestResult { + + private static final ParseField FIELD_1 = new ParseField("field_1"); + private static final ParseField FIELD_2 = new ParseField("field_2"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("test_result", + a -> new TestResult((String) a[0], (Double) a[1])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_1); + PARSER.declareDouble(ConstructingObjectParser.constructorArg(), FIELD_2); + } + + private final String field1; + private final double field2; + + private TestResult(String field1, double field2) { + this.field1 = field1; + this.field2 = field2; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + TestResult that = (TestResult) other; + return Objects.equals(field1, that.field1) && Objects.equals(field2, that.field2); + } + + @Override + public int hashCode() { + return Objects.hash(field1, field2); + } + } +} From 9bb2a02f44c4b719902ab5d254902afa15514804 Mon Sep 17 00:00:00 2001 From: 
Hendrik Muhs Date: Tue, 18 Dec 2018 08:53:07 +0100 Subject: [PATCH 11/67] change the download location of the ml native code build (#36733) change the download location to load the custom binaries created in elastic/ml-cpp#344 --- x-pack/plugin/ml/cpp-snapshot/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/cpp-snapshot/build.gradle b/x-pack/plugin/ml/cpp-snapshot/build.gradle index 1c35d9db6f321..85455d6311ea0 100644 --- a/x-pack/plugin/ml/cpp-snapshot/build.gradle +++ b/x-pack/plugin/ml/cpp-snapshot/build.gradle @@ -9,7 +9,7 @@ ext.version = VersionProperties.elasticsearch // for this project so it can be used with dependency substitution. void getZip(File snapshotZip) { - String zipUrl = "http://prelert-artifacts.s3.amazonaws.com/maven/org/elasticsearch/ml/ml-cpp/${version}/ml-cpp-${version}.zip" + String zipUrl = "http://prelert-artifacts.s3.amazonaws.com/maven/org/elasticsearch/ml/ml-cpp-df/${version}/ml-cpp-df-${version}.zip" File snapshotMd5 = new File(snapshotZip.toString() + '.md5') HttpURLConnection conn = (HttpURLConnection) new URL(zipUrl).openConnection(); From 371f5d7e9324fd5198ec077be5f32dd471a3521a Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 11 Jan 2019 17:04:07 +0200 Subject: [PATCH 12/67] [FEATURE][ML] Fix a few minor issues after merging master --- .../xpack/ml/action/TransportRunAnalyticsAction.java | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java index f8f1226492941..7a8367f861557 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -87,14 +87,14 @@ public TransportRunAnalyticsAction(ThreadPool 
threadPool, TransportService trans @Override protected void doExecute(Task task, RunAnalyticsAction.Request request, ActionListener listener) { DiscoveryNode localNode = clusterService.localNode(); - if (isMlNode(localNode)) { + if (MachineLearning.isMlNode(localNode)) { reindexDataframeAndStartAnalysis(request.getIndex(), listener); return; } ClusterState clusterState = clusterService.state(); for (DiscoveryNode node : clusterState.getNodes()) { - if (isMlNode(node)) { + if (MachineLearning.isMlNode(node)) { transportService.sendRequest(node, actionName, request, new ActionListenerResponseHandler<>(listener, inputStream -> { AcknowledgedResponse response = new AcknowledgedResponse(); @@ -107,12 +107,6 @@ protected void doExecute(Task task, RunAnalyticsAction.Request request, ActionLi listener.onFailure(ExceptionsHelper.badRequestException("No ML node to run on")); } - private boolean isMlNode(DiscoveryNode node) { - Map nodeAttributes = node.getAttributes(); - String enabled = nodeAttributes.get(MachineLearning.ML_ENABLED_NODE_ATTR); - return Boolean.valueOf(enabled); - } - private void reindexDataframeAndStartAnalysis(String index, ActionListener listener) { final String destinationIndex = index + "_copy"; From 9fd1d101c33b132f55266375d4ecea18d251b057 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 11 Jan 2019 19:15:19 +0200 Subject: [PATCH 13/67] [FEATURE][ML] Add checksum checks on dataframe result joining (#37259) In order to sanity check that analytics results are joined correctly with their corresponding dataframe rows, we write a checksum for each dataframe row which is a 32-bit hash of the analysis fields. The analytics process includes it in the results. Upon joining we check that the checksums match. 
--- .../ml/analytics/DataFrameDataExtractor.java | 5 +++ .../DataFrameDataExtractorFactory.java | 5 ++- .../process/AnalyticsProcessManager.java | 14 +++++--- .../ml/analytics/process/AnalyticsResult.java | 22 ++++++------- .../process/AnalyticsResultProcessor.java | 12 +++++++ .../NativeAnalyticsProcessFactory.java | 4 +-- .../DataFrameDataExtractorFactoryTests.java | 24 ++++++++++++++ .../AnalyticsResultProcessorTests.java | 33 ++++++++++++++++--- .../process/AnalyticsResultTests.java | 4 +-- 9 files changed, 98 insertions(+), 25 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java index 3d47aeff1b8b5..ee03e7660ef87 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.NoSuchElementException; import java.util.Objects; @@ -238,5 +239,9 @@ public SearchHit getHit() { public boolean shouldSkip() { return values == null; } + + public int getChecksum() { + return Arrays.hashCode(values); + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java index 57c2b44f769c4..c3622b49dab82 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java @@ -109,7 +109,10 @@ static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapa Set fields = 
fieldCapabilitiesResponse.get().keySet(); fields.removeAll(IGNORE_FIELDS); removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse); - ExtractedFields extractedFields = ExtractedFields.build(new ArrayList<>(fields), Collections.emptySet(), fieldCapabilitiesResponse) + List sortedFields = new ArrayList<>(fields); + // We sort the fields to ensure the checksum for each document is deterministic + Collections.sort(sortedFields); + ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse) .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE); if (extractedFields.getAllFields().isEmpty()) { throw ExceptionsHelper.badRequestException("No compatible fields could be detected"); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java index bf079348c7e55..2b55e8c648d1d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java @@ -78,8 +78,8 @@ private void processData(String jobId, DataFrameDataExtractor dataExtractor, Ana } private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { - // The extra field is the control field (should be an empty string) - String[] record = new String[dataExtractor.getFieldNames().size() + 1]; + // The extra fields are for the doc hash and the control field (should be an empty string) + String[] record = new String[dataExtractor.getFieldNames().size() + 2]; // The value of the control field should be an empty string for data frame rows record[record.length - 1] = ""; @@ -90,6 +90,7 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces if 
(row.shouldSkip() == false) { String[] rowValues = row.getValues(); System.arraycopy(rowValues, 0, record, 0, rowValues.length); + record[record.length - 2] = String.valueOf(row.getChecksum()); process.writeRecord(record); } } @@ -99,11 +100,16 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsProcess process) throws IOException { List fieldNames = dataExtractor.getFieldNames(); - String[] headerRecord = new String[fieldNames.size() + 1]; + + // We add 2 extra fields, both named dot: + // - the document hash + // - the control message + String[] headerRecord = new String[fieldNames.size() + 2]; for (int i = 0; i < fieldNames.size(); i++) { headerRecord[i] = fieldNames.get(i); } - // The field name of the control field is dot + + headerRecord[headerRecord.length - 2] = "."; headerRecord[headerRecord.length - 1] = "."; process.writeRecord(headerRecord); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java index 1f9ef71da8fb1..3b34537a56e81 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java @@ -17,27 +17,27 @@ public class AnalyticsResult implements ToXContentObject { public static final ParseField TYPE = new ParseField("analytics_result"); - public static final ParseField ID_HASH = new ParseField("id_hash"); + public static final ParseField CHECKSUM = new ParseField("checksum"); public static final ParseField RESULTS = new ParseField("results"); static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), - a -> new AnalyticsResult((String) a[0], (Map) a[1])); + a -> new AnalyticsResult((Integer) a[0], 
(Map) a[1])); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), ID_HASH); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), CHECKSUM); PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS); } - private final String idHash; + private final int checksum; private final Map results; - public AnalyticsResult(String idHash, Map results) { - this.idHash = Objects.requireNonNull(idHash); + public AnalyticsResult(int checksum, Map results) { + this.checksum = Objects.requireNonNull(checksum); this.results = Objects.requireNonNull(results); } - public String getIdHash() { - return idHash; + public int getChecksum() { + return checksum; } public Map getResults() { @@ -47,7 +47,7 @@ public Map getResults() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(ID_HASH.getPreferredName(), idHash); + builder.field(CHECKSUM.getPreferredName(), checksum); builder.field(RESULTS.getPreferredName(), results); builder.endObject(); return builder; @@ -63,11 +63,11 @@ public boolean equals(Object other) { } AnalyticsResult that = (AnalyticsResult) other; - return Objects.equals(idHash, that.idHash) && Objects.equals(results, that.results); + return checksum == that.checksum && Objects.equals(results, that.results); } @Override public int hashCode() { - return Objects.hash(idHash, results); + return Objects.hash(checksum, results); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java index bdb1526b1b78a..0dbbf1b8b22d7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java @@ -96,6 +96,8 @@ private void joinCurrentResults() { continue; } AnalyticsResult result = currentResults.get(i); + checkChecksumsMatch(row, result); + SearchHit hit = row.getHit(); Map source = new LinkedHashMap(hit.getSourceAsMap()); source.putAll(result.getResults()); @@ -112,4 +114,14 @@ private void joinCurrentResults() { } } } + + private void checkChecksumsMatch(DataFrameDataExtractor.Row row, AnalyticsResult result) { + if (row.getChecksum() != result.getChecksum()) { + String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; "; + msg += "expected [" + row.getChecksum() + "] but result had [" + result.getChecksum() + "]; "; + msg += "this implies the data frame index [" + row.getHit().getIndex() + "] was modified while the analysis was running. "; + msg += "We rely on this index being immutable during a running analysis and so the results will be unreliable."; + throw new IllegalStateException(msg); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java index 3943e7fd7924f..0b16bdb715bb3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java @@ -45,8 +45,8 @@ public AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessCon ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId, true, false, true, true, false, false); - // The extra 1 is the control field - int numberOfFields = analyticsProcessConfig.cols() + 1; + // The extra 2 are for the checksum and the control field + int 
numberOfFields = analyticsProcessConfig.cols() + 2; createNativeProcess(jobId, analyticsProcessConfig, filesToDelete, processPipes); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java index 1a43b2893baef..efbc19563cc4d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java @@ -12,6 +12,8 @@ import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -89,6 +91,28 @@ public void testDetectExtractedFields_GivenIgnoredField() { assertThat(e.getMessage(), equalTo("No compatible fields could be detected")); } + public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() { + int fieldCount = randomIntBetween(10, 20); + List fields = new ArrayList<>(); + for (int i = 0; i < fieldCount; i++) { + fields.add(randomAlphaOfLength(20)); + } + List sortedFields = new ArrayList<>(fields); + Collections.sort(sortedFields); + + MockFieldCapsResponseBuilder mockFieldCapsResponseBuilder = new MockFieldCapsResponseBuilder(); + for (String field : fields) { + mockFieldCapsResponseBuilder.addAggregatableField(field, "float"); + } + FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); + + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, 
equalTo(sortedFields)); + } + private static class MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java index 3d71b4d430e2b..e4de63908a5f9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java @@ -63,13 +63,13 @@ public void testProcess_GivenSingleRowAndResult() throws IOException { String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; String[] dataValues = {"42.0"}; - DataFrameDataExtractor.Row row = newRow(newHit("1", dataDoc), dataValues); + DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); givenSingleDataFrameBatch(Arrays.asList(row)); Map resultFields = new HashMap<>(); resultFields.put("a", "1"); resultFields.put("b", "2"); - AnalyticsResult result = new AnalyticsResult("some_hash", resultFields); + AnalyticsResult result = new AnalyticsResult(1, resultFields); givenProcessResults(Arrays.asList(result)); AnalyticsResultProcessor resultProcessor = createResultProcessor(); @@ -90,6 +90,28 @@ public void testProcess_GivenSingleRowAndResult() throws IOException { assertThat(indexedDocSource.get("b"), equalTo("2")); } + public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); + givenSingleDataFrameBatch(Arrays.asList(row)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + AnalyticsResult result = new AnalyticsResult(2, 
resultFields); + givenProcessResults(Arrays.asList(result)); + + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verifyNoMoreInteractions(client); + } + private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } @@ -99,16 +121,17 @@ private void givenSingleDataFrameBatch(List batch) t when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty()); } - private static SearchHit newHit(String id, String json) { - SearchHit hit = new SearchHit(42, id, new Text("doc"), Collections.emptyMap()); + private static SearchHit newHit(String json) { + SearchHit hit = new SearchHit(randomInt(), randomAlphaOfLength(10), new Text("doc"), Collections.emptyMap()); hit.sourceRef(new BytesArray(json)); return hit; } - private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values) { + private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values, int checksum) { DataFrameDataExtractor.Row row = mock(DataFrameDataExtractor.Row.class); when(row.getHit()).thenReturn(hit); when(row.getValues()).thenReturn(values); + when(row.getChecksum()).thenReturn(checksum); return row; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java index 6250a96cd3284..c243e4c871d78 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java @@ -16,7 +16,7 @@ public class AnalyticsResultTests extends AbstractXContentTestCase results = new HashMap<>(); int resultsSize = randomIntBetween(1, 10); for (int i = 0; i < resultsSize; i++) { @@ -24,7 +24,7 @@ protected 
AnalyticsResult createTestInstance() { Object resultValue = randomBoolean() ? randomAlphaOfLength(20) : randomDouble(); results.put(resultField, resultValue); } - return new AnalyticsResult(idHash, results); + return new AnalyticsResult(checksum, results); } @Override From 16552643e6a67c00afe0e244d3a8daa7918ff4da Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 21 Jan 2019 11:12:23 +0200 Subject: [PATCH 14/67] [ML] Rename analytics package to dataframe (#37583) --- .../java/org/elasticsearch/xpack/ml/MachineLearning.java | 6 +++--- .../xpack/ml/action/TransportRunAnalyticsAction.java | 6 +++--- .../ml/{analytics => dataframe}/DataFrameAnalysis.java | 2 +- .../{analytics => dataframe}/DataFrameDataExtractor.java | 2 +- .../DataFrameDataExtractorContext.java | 2 +- .../DataFrameDataExtractorFactory.java | 2 +- .../ml/{analytics => dataframe}/DataFrameFields.java | 2 +- .../process/AnalyticsBuilder.java | 2 +- .../process/AnalyticsControlMessageWriter.java | 2 +- .../process/AnalyticsProcess.java | 2 +- .../process/AnalyticsProcessConfig.java | 4 ++-- .../process/AnalyticsProcessFactory.java | 2 +- .../process/AnalyticsProcessManager.java | 8 ++++---- .../{analytics => dataframe}/process/AnalyticsResult.java | 2 +- .../process/AnalyticsResultProcessor.java | 4 ++-- .../process/NativeAnalyticsProcess.java | 2 +- .../process/NativeAnalyticsProcessFactory.java | 2 +- .../DataFrameDataExtractorFactoryTests.java | 2 +- .../process/AnalyticsControlMessageWriterTests.java | 2 +- .../process/AnalyticsResultProcessorTests.java | 4 ++-- .../process/AnalyticsResultTests.java | 2 +- 21 files changed, 31 insertions(+), 31 deletions(-) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/DataFrameAnalysis.java (95%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/DataFrameDataExtractor.java (99%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => 
dataframe}/DataFrameDataExtractorContext.java (96%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/DataFrameDataExtractorFactory.java (99%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/DataFrameFields.java (88%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsBuilder.java (98%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsControlMessageWriter.java (96%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsProcess.java (94%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsProcessConfig.java (93%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsProcessFactory.java (93%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsProcessManager.java (96%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsResult.java (97%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsResultProcessor.java (97%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/NativeAnalyticsProcess.java (97%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/NativeAnalyticsProcessFactory.java (98%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/DataFrameDataExtractorFactoryTests.java (99%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsControlMessageWriterTests.java (97%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/{analytics => 
dataframe}/process/AnalyticsResultProcessorTests.java (98%) rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/{analytics => dataframe}/process/AnalyticsResultTests.java (96%) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index bd80951d5c758..1fce028c727df 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -165,9 +165,9 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; -import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessFactory; -import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessManager; -import org.elasticsearch.xpack.ml.analytics.process.NativeAnalyticsProcessFactory; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; +import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory; import org.elasticsearch.xpack.ml.datafeed.DatafeedJobBuilder; import org.elasticsearch.xpack.ml.datafeed.DatafeedManager; import org.elasticsearch.xpack.ml.datafeed.persistence.DatafeedConfigProvider; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java index 7a8367f861557..cb9e25504a8e4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java @@ -38,9 +38,9 @@ import 
org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory; -import org.elasticsearch.xpack.ml.analytics.DataFrameFields; -import org.elasticsearch.xpack.ml.analytics.process.AnalyticsProcessManager; +import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.dataframe.DataFrameFields; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; import java.util.Arrays; import java.util.Collections; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalysis.java similarity index 95% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameAnalysis.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalysis.java index 1b06e77e31b5a..81062f9795040 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameAnalysis.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalysis.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics; +package org.elasticsearch.xpack.ml.dataframe; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ToXContentObject; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractor.java similarity index 99% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractor.java index ee03e7660ef87..3d17ff7afd2c5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractor.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics; +package org.elasticsearch.xpack.ml.dataframe; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorContext.java similarity index 96% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorContext.java index d1b52bdac0351..82de257ccaffd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorContext.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics; +package org.elasticsearch.xpack.ml.dataframe; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactory.java similarity index 99% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactory.java index c3622b49dab82..3080cd4b43a27 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactory.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics; +package org.elasticsearch.xpack.ml.dataframe; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameFields.java similarity index 88% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameFields.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameFields.java index 8e7a8dd61a8ed..164b7888a6ffe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/DataFrameFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameFields.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.analytics; +package org.elasticsearch.xpack.ml.dataframe; public final class DataFrameFields { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsBuilder.java similarity index 98% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsBuilder.java index e2b81ff547cc6..4d58a132bab5b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsBuilder.java @@ -3,7 +3,7 @@ * or more contributor license agreements. 
Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ToXContent; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriter.java similarity index 96% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriter.java index 0500b51f85b2a..c66754171fc29 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriter.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.xpack.ml.process.writer.AbstractControlMsgWriter; import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java similarity index 94% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java index dc07d688a67ae..0d925ed439d3f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.xpack.ml.process.NativeProcess; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java similarity index 93% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 4ee543ed078a4..ce186b9232da3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -3,12 +3,12 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.ml.analytics.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalysis; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java similarity index 93% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java index d0eb7a414074f..d09757ddc5c74 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import java.util.concurrent.ExecutorService; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java similarity index 96% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 2b55e8c648d1d..987335d22f0a6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -15,9 +15,9 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.analytics.DataFrameAnalysis; -import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; -import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractorFactory; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java similarity index 97% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java index 3b34537a56e81..3e9c1b8b9cd57 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.ConstructingObjectParser; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java similarity index 97% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 0dbbf1b8b22d7..f6be7a4e78e36 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -14,7 +14,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.client.Client; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractor; import java.util.ArrayList; import java.util.Iterator; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java similarity index 97% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java index 5f0f58e8b7b8a..8934de51f3772 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.xpack.ml.process.AbstractNativeProcess; import org.elasticsearch.xpack.ml.process.ProcessResultsParser; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java similarity index 98% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java index 0b16bdb715bb3..4cb6f344c7019 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/analytics/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactoryTests.java similarity index 99% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactoryTests.java index efbc19563cc4d..82c492f579e10 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/DataFrameDataExtractorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactoryTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics; +package org.elasticsearch.xpack.ml.dataframe; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.fieldcaps.FieldCapabilities; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriterTests.java similarity index 97% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriterTests.java index 3845b9df26b4c..5f0cce1770227 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsControlMessageWriterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsControlMessageWriterTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java similarity index 98% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index e4de63908a5f9..1d38049ff23ce 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.bulk.BulkAction; @@ -16,7 +16,7 @@ import org.elasticsearch.common.text.Text; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.ml.analytics.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractor; import org.junit.Before; import org.mockito.ArgumentCaptor; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java similarity index 96% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java index c243e4c871d78..fc46b4e984d26 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/analytics/process/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.analytics.process; +package org.elasticsearch.xpack.ml.dataframe.process; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; From 551f9a870047132e0846c05bb55a9c8e2a5177af Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 4 Feb 2019 14:39:18 +0200 Subject: [PATCH 15/67] [ML] Data frame analytics as persistent tasks (#37919) Converts data frame analytics to run as persistent tasks. 
Adds the following APIs: - PUT _ml/data_frame/analysis/{id} - GET _ml/data_frame/analysis/{id} - GET _ml/data_frame/analysis/{id}/_stats - POST _ml/data_frame/analysis/{id}/_start - DELETE _ml/data_frame/analysis/{id} --- .../xpack/core/XPackClientPlugin.java | 14 +- .../elasticsearch/xpack/core/ml/MlTasks.java | 27 ++ .../action/AbstractGetResourcesRequest.java | 75 ++++ .../action/AbstractGetResourcesResponse.java | 84 ++++ .../DeleteDataFrameAnalyticsAction.java | 81 ++++ .../action/GetDataFrameAnalyticsAction.java | 69 ++++ .../GetDataFrameAnalyticsStatsAction.java | 279 +++++++++++++ .../action/PutDataFrameAnalyticsAction.java | 153 ++++++++ .../core/ml/action/RunAnalyticsAction.java | 109 ------ .../action/StartDataFrameAnalyticsAction.java | 165 ++++++++ .../ml/dataframe/DataFrameAnalysisConfig.java | 66 ++++ .../dataframe/DataFrameAnalyticsConfig.java | 163 ++++++++ .../ml/dataframe/DataFrameAnalyticsState.java | 36 ++ .../DataFrameAnalyticsTaskState.java | 97 +++++ .../persistence/ElasticsearchMappings.java | 28 ++ .../ml/job/results/ReservedFieldNames.java | 9 + .../core/ml/process/writer/RecordWriter.java | 2 +- .../xpack/core/ml/utils/ExceptionsHelper.java | 13 + ...DataFrameAnalyticsActionResponseTests.java | 33 ++ .../GetDataFrameAnalyticsRequestTests.java | 27 ++ ...rameAnalyticsStatsActionResponseTests.java | 36 ++ ...etDataFrameAnalyticsStatsRequestTests.java | 26 ++ ...tDataFrameAnalyticsActionRequestTests.java | 42 ++ ...DataFrameAnalyticsActionResponseTests.java | 23 ++ .../StartDataFrameAnalyticsRequestTests.java | 23 ++ .../DataFrameAnalysisConfigTests.java | 47 +++ .../DataFrameAnalyticsConfigTests.java | 45 +++ .../integration/MlRestTestStateCleaner.java | 17 + .../ml/qa/ml-with-security/build.gradle | 14 + .../xpack/ml/MachineLearning.java | 54 ++- .../AbstractTransportGetResourcesAction.java | 136 +++++++ ...ansportDeleteDataFrameAnalyticsAction.java | 96 +++++ .../TransportGetDataFrameAnalyticsAction.java | 66 ++++ 
...sportGetDataFrameAnalyticsStatsAction.java | 97 +++++ .../TransportPutDataFrameAnalyticsAction.java | 74 ++++ ...ransportStartDataFrameAnalyticsAction.java | 192 +++++++++ .../xpack/ml/dataframe/DataFrameAnalysis.java | 32 -- ...lds.java => DataFrameAnalyticsFields.java} | 4 +- .../DataFrameAnalyticsManager.java} | 183 ++++----- .../analyses/AbstractDataFrameAnalysis.java | 28 ++ .../analyses/DataFrameAnalysesUtils.java | 80 ++++ .../dataframe/analyses/DataFrameAnalysis.java | 47 +++ .../dataframe/analyses/OutlierDetection.java | 64 +++ .../DataFrameDataExtractor.java | 7 +- .../DataFrameDataExtractorContext.java | 2 +- .../DataFrameDataExtractorFactory.java | 17 +- .../DataFrameAnalyticsConfigProvider.java | 93 +++++ .../dataframe/process/AnalyticsProcess.java | 2 +- .../process/AnalyticsProcessConfig.java | 2 +- .../process/AnalyticsProcessManager.java | 37 +- .../process/AnalyticsResultProcessor.java | 6 +- .../NativeAnalyticsProcessFactory.java | 4 +- .../writer/AbstractDataToProcessWriter.java | 4 +- .../RestDeleteDataFrameAnalyticsAction.java | 39 ++ .../RestGetDataFrameAnalyticsAction.java | 50 +++ .../RestGetDataFrameAnalyticsStatsAction.java | 50 +++ .../RestPutDataFrameAnalyticsAction.java | 43 ++ .../RestStartDataFrameAnalyticsAction.java} | 22 +- .../analyses/DataFrameAnalysesUtilsTests.java | 77 ++++ .../analyses/OutlierDetectionTests.java | 60 +++ .../DataFrameDataExtractorFactoryTests.java | 24 +- .../AnalyticsResultProcessorTests.java | 2 +- .../api/ml.delete_data_frame_analytics.json | 17 + .../api/ml.get_data_frame_analytics.json | 29 ++ .../ml.get_data_frame_analytics_stats.json | 29 ++ .../api/ml.put_data_frame_analytics.json | 20 + .../api/ml.start_data_frame_analytics.json | 16 + .../test/ml/data_frame_analytics_crud.yml | 370 ++++++++++++++++++ .../test/ml/start_data_frame_analytics.yml | 46 +++ 69 files changed, 3715 insertions(+), 309 deletions(-) create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java create mode 100644 
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalysis.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/{DataFrameFields.java => DataFrameAnalyticsFields.java} (79%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/{action/TransportRunAnalyticsAction.java => dataframe/DataFrameAnalyticsManager.java} (50%) create mode 100644 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/{ => extractor}/DataFrameDataExtractor.java (97%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/{ => extractor}/DataFrameDataExtractorContext.java (95%) rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/{ => extractor}/DataFrameDataExtractorFactory.java (90%) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestDeleteDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestPutDataFrameAnalyticsAction.java rename x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/{analytics/RestRunAnalyticsAction.java => dataframe/RestStartDataFrameAnalyticsAction.java} (50%) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java rename 
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/{ => extractor}/DataFrameDataExtractorFactoryTests.java (91%) create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 301cda08d5dfb..c765b0390bf80 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -53,9 +53,9 @@ import org.elasticsearch.xpack.core.indexlifecycle.IndexLifecycleMetadata; import org.elasticsearch.xpack.core.indexlifecycle.LifecycleAction; import org.elasticsearch.xpack.core.indexlifecycle.LifecycleType; -import org.elasticsearch.xpack.core.indexlifecycle.SetPriorityAction; import org.elasticsearch.xpack.core.indexlifecycle.ReadOnlyAction; import org.elasticsearch.xpack.core.indexlifecycle.RolloverAction; +import org.elasticsearch.xpack.core.indexlifecycle.SetPriorityAction; import org.elasticsearch.xpack.core.indexlifecycle.ShrinkAction; import org.elasticsearch.xpack.core.indexlifecycle.TimeseriesLifecycleType; import org.elasticsearch.xpack.core.indexlifecycle.UnfollowAction; @@ -87,6 +87,8 
@@ import org.elasticsearch.xpack.core.ml.action.GetCalendarEventsAction; import org.elasticsearch.xpack.core.ml.action.GetCalendarsAction; import org.elasticsearch.xpack.core.ml.action.GetCategoriesAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetFiltersAction; @@ -105,12 +107,13 @@ import org.elasticsearch.xpack.core.ml.action.PostDataAction; import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutCalendarAction; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; -import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; @@ -294,8 +297,11 @@ public List> getClientActions() { PostCalendarEventsAction.INSTANCE, PersistJobAction.INSTANCE, FindFileStructureAction.INSTANCE, - RunAnalyticsAction.INSTANCE, SetUpgradeModeAction.INSTANCE, + PutDataFrameAnalyticsAction.INSTANCE, + GetDataFrameAnalyticsAction.INSTANCE, + GetDataFrameAnalyticsStatsAction.INSTANCE, + StartDataFrameAnalyticsAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, @@ 
-450,6 +456,8 @@ public List getNamedXContent() { StartDatafeedAction.DatafeedParams::fromXContent), new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.JOB_TASK_NAME), OpenJobAction.JobParams::fromXContent), + new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME), + StartDataFrameAnalyticsAction.TaskParams::fromXContent), // ML - Task states new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(DatafeedState.NAME), DatafeedState::fromXContent), new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(JobTaskState.NAME), JobTaskState::fromXContent), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index cd32505a48e3e..649a77648eafb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -11,6 +11,8 @@ import org.elasticsearch.persistent.PersistentTasksClusterService; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; @@ -23,9 +25,11 @@ public final class MlTasks { public static final String JOB_TASK_NAME = "xpack/ml/job"; public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed"; + public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics"; public static final String JOB_TASK_ID_PREFIX = "job-"; public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-"; + private static final String 
DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; public static final PersistentTasksCustomMetaData.Assignment AWAITING_UPGRADE = new PersistentTasksCustomMetaData.Assignment(null, @@ -50,6 +54,13 @@ public static String datafeedTaskId(String datafeedId) { return DATAFEED_TASK_ID_PREFIX + datafeedId; } + /** + * Namespaces the task ids for data frame analytics. + */ + public static String dataFrameAnalyticsTaskId(String id) { + return DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + id; + } + @Nullable public static PersistentTasksCustomMetaData.PersistentTask getJobTask(String jobId, @Nullable PersistentTasksCustomMetaData tasks) { return tasks == null ? null : tasks.getTask(jobTaskId(jobId)); @@ -61,6 +72,12 @@ public static PersistentTasksCustomMetaData.PersistentTask getDatafeedTask(St return tasks == null ? null : tasks.getTask(datafeedTaskId(datafeedId)); } + @Nullable + public static PersistentTasksCustomMetaData.PersistentTask getDataFrameAnalyticsTask(String analyticsId, + @Nullable PersistentTasksCustomMetaData tasks) { + return tasks == null ? null : tasks.getTask(dataFrameAnalyticsTaskId(analyticsId)); + } + /** * Note that the return value of this method does NOT take node relocations into account. * Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most @@ -120,6 +137,16 @@ public static DatafeedState getDatafeedState(String datafeedId, @Nullable Persis } } + public static DataFrameAnalyticsState getDataFrameAnalyticsState(String analyticsId, @Nullable PersistentTasksCustomMetaData tasks) { + PersistentTasksCustomMetaData.PersistentTask task = getDataFrameAnalyticsTask(analyticsId, tasks); + if (task != null && task.getState() != null) { + DataFrameAnalyticsTaskState taskState = (DataFrameAnalyticsTaskState) task.getState(); + return taskState.getState(); + } else { + return DataFrameAnalyticsState.STOPPED; + } + } + /** * The job Ids of anomaly detector job tasks. 
* All anomaly detector jobs are returned regardless of the status of the diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java new file mode 100644 index 0000000000000..baa7d2714ec8b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.action.util.PageParams; + +import java.io.IOException; +import java.util.Objects; + +public abstract class AbstractGetResourcesRequest extends ActionRequest { + + private String resourceId; + private PageParams pageParams = PageParams.defaultParams(); + + public void setResourceId(String resourceId) { + this.resourceId = resourceId; + } + + public String getResourceId() { + return resourceId; + } + + public void setPageParams(PageParams pageParams) { + this.pageParams = pageParams; + } + + public PageParams getPageParams() { + return pageParams; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + resourceId = in.readOptionalString(); + pageParams = in.readOptionalWriteable(PageParams::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
super.writeTo(out); + out.writeOptionalString(resourceId); + out.writeOptionalWriteable(pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(resourceId, pageParams); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (obj instanceof AbstractGetResourcesRequest == false) { + return false; + } + AbstractGetResourcesRequest other = (AbstractGetResourcesRequest) obj; + return Objects.equals(resourceId, other.resourceId); + } + + public abstract String getResourceIdField(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java new file mode 100644 index 0000000000000..7f7686ff230ba --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.StatusToXContentObject; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; + +import java.io.IOException; +import java.util.Objects; + +public abstract class AbstractGetResourcesResponse extends ActionResponse + implements StatusToXContentObject { + + private QueryPage resources; + + protected AbstractGetResourcesResponse() {} + + protected AbstractGetResourcesResponse(QueryPage resources) { + this.resources = Objects.requireNonNull(resources); + } + + public QueryPage getResources() { + return resources; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + resources = new QueryPage<>(in, getReader()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + resources.writeTo(out); + } + + @Override + public RestStatus status() { + return RestStatus.OK; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + resources.doXContentBody(builder, params); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(resources); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (obj instanceof AbstractGetResourcesResponse == false) { + return false; + } + AbstractGetResourcesResponse other = (AbstractGetResourcesResponse) obj; + return Objects.equals(resources, other.resources); + } + 
+ @Override + public final String toString() { + return Strings.toString(this); + } + protected abstract Reader getReader(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..27ee6b0a97139 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeOperationRequestBuilder; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DeleteDataFrameAnalyticsAction extends Action { + + public static final DeleteDataFrameAnalyticsAction INSTANCE = new DeleteDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/delete"; + + private DeleteDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public AcknowledgedResponse newResponse() { + return new AcknowledgedResponse(); + } + + public 
static class Request extends AcknowledgedRequest implements ToXContentFragment { + + private String id; + + public Request() {} + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public String getId() { + return id; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DeleteDataFrameAnalyticsAction.Request request = (DeleteDataFrameAnalyticsAction.Request) o; + return Objects.equals(id, request.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + + public static class RequestBuilder extends MasterNodeOperationRequestBuilder { + + protected RequestBuilder(ElasticsearchClient client, DeleteDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..264b996b3e8f4 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; + +import java.io.IOException; +import java.util.Collections; + +public class GetDataFrameAnalyticsAction extends Action { + + public static final GetDataFrameAnalyticsAction INSTANCE = new GetDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/get"; + + private GetDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + return new Response(new QueryPage<>(Collections.emptyList(), 0, Response.RESULTS_FIELD)); + } + + public static class Request extends AbstractGetResourcesRequest { + + public Request() {} + + public Request(StreamInput in) throws IOException { + readFrom(in); + } + + @Override + public String getResourceIdField() { + return DataFrameAnalyticsConfig.ID.getPreferredName(); + } + } + + public static class Response extends AbstractGetResourcesResponse { + + public static final ParseField RESULTS_FIELD = new ParseField("data_frame_analytics"); + + public Response() {} + + public Response(QueryPage analytics) { + super(analytics); + } + + @Override + protected Reader getReader() { + return DataFrameAnalyticsConfig::new; + } + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client) { + super(client, INSTANCE, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java new file mode 100644 index 0000000000000..e7824bb08fb7a --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -0,0 +1,279 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class GetDataFrameAnalyticsStatsAction extends Action { + + public static final GetDataFrameAnalyticsStatsAction INSTANCE = new GetDataFrameAnalyticsStatsAction(); + public static final String 
NAME = "cluster:monitor/xpack/ml/data_frame/analytics/stats/get"; + + private GetDataFrameAnalyticsStatsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public Writeable.Reader getResponseReader() { + return Response::new; + } + + public static class Request extends MasterNodeRequest { + + private String id; + private PageParams pageParams = PageParams.defaultParams(); + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID.getPreferredName()); + } + + public Request() {} + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + pageParams = in.readOptionalWriteable(PageParams::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + out.writeOptionalWriteable(pageParams); + } + + public void setId(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public void setPageParams(PageParams pageParams) { + this.pageParams = pageParams; + } + + public PageParams getPageParams() { + return pageParams; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public int hashCode() { + return Objects.hash(id, pageParams); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Request other = (Request) obj; + return Objects.equals(id, other.id) && Objects.equals(pageParams, other.pageParams); + } + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client, GetDataFrameAnalyticsStatsAction action) { + super(client, action, new Request()); + } + } + + public static class Response extends ActionResponse implements 
ToXContentObject { + + public static class Stats implements ToXContentObject, Writeable { + + private final String id; + private final DataFrameAnalyticsState state; + @Nullable + private final DiscoveryNode node; + @Nullable + private final String assignmentExplanation; + + public Stats(String id, DataFrameAnalyticsState state, @Nullable DiscoveryNode node, + @Nullable String assignmentExplanation) { + this.id = Objects.requireNonNull(id); + this.state = Objects.requireNonNull(state); + this.node = node; + this.assignmentExplanation = assignmentExplanation; + } + + public Stats(StreamInput in) throws IOException { + id = in.readString(); + state = DataFrameAnalyticsState.fromStream(in); + node = in.readOptionalWriteable(DiscoveryNode::new); + assignmentExplanation = in.readOptionalString(); + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsState getState() { + return state; + } + + public DiscoveryNode getNode() { + return node; + } + + public String getAssignmentExplanation() { + return assignmentExplanation; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // TODO: Have callers wrap the content with an object as they choose rather than forcing it upon them + builder.startObject(); + { + toUnwrappedXContent(builder); + } + return builder.endObject(); + } + + public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOException { + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + builder.field("state", state.toString()); + if (node != null) { + builder.startObject("node"); + builder.field("id", node.getId()); + builder.field("name", node.getName()); + builder.field("ephemeral_id", node.getEphemeralId()); + builder.field("transport_address", node.getAddress().toString()); + + builder.startObject("attributes"); + for (Map.Entry entry : node.getAttributes().entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + 
builder.endObject(); + builder.endObject(); + } + if (assignmentExplanation != null) { + builder.field("assignment_explanation", assignmentExplanation); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + state.writeTo(out); + out.writeOptionalWriteable(node); + out.writeOptionalString(assignmentExplanation); + } + + @Override + public int hashCode() { + return Objects.hash(id, state, node, assignmentExplanation); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Stats other = (Stats) obj; + return Objects.equals(id, other.id) + && Objects.equals(this.state, other.state) + && Objects.equals(this.node, other.node) + && Objects.equals(this.assignmentExplanation, other.assignmentExplanation); + } + } + + private QueryPage stats; + + public Response() {} + + public Response(QueryPage stats) { + this.stats = stats; + } + + public Response(StreamInput in) throws IOException { + super(in); + stats = new QueryPage<>(in, Stats::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + stats.writeTo(out); + } + + public QueryPage getResponse() { + return stats; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + stats.doXContentBody(builder, params); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(stats); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Response other = (Response) obj; + return Objects.equals(stats, other.stats); + } + + @Override + public final String toString() { + return Strings.toString(this); + } + } + +} diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..e447aa70109e7 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.MasterNodeOperationRequestBuilder; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; + +import java.io.IOException; +import java.util.Objects; + +public class PutDataFrameAnalyticsAction extends Action { + + public static final PutDataFrameAnalyticsAction INSTANCE = new PutDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/put"; + + private PutDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + return new Response(); + } + + public 
static class Request extends AcknowledgedRequest implements ToXContentObject { + + public static Request parseRequest(String id, XContentParser parser) { + DataFrameAnalyticsConfig.Builder config = DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null); + if (config.getId() == null) { + config.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(config.getId())) { + // If we have both URI and body ID, they must be identical + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + config.getId(), id)); + } + + return new PutDataFrameAnalyticsAction.Request(config.build()); + } + + private DataFrameAnalyticsConfig config; + + public Request() {} + + public Request(DataFrameAnalyticsConfig config) { + this.config = config; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + config = new DataFrameAnalyticsConfig(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + config.writeTo(out); + } + + public DataFrameAnalyticsConfig getConfig() { + return config; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + config.toXContent(builder, params); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PutDataFrameAnalyticsAction.Request request = (PutDataFrameAnalyticsAction.Request) o; + return Objects.equals(config, request.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } + } + + public static class Response extends ActionResponse implements ToXContentObject { + + private DataFrameAnalyticsConfig config; + + public Response(DataFrameAnalyticsConfig config) { + this.config = config; + } + + Response() {} 
+ + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + config = new DataFrameAnalyticsConfig(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + config.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + config.toXContent(builder, params); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response response = (Response) o; + return Objects.equals(config, response.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } + } + + public static class RequestBuilder extends MasterNodeOperationRequestBuilder { + + protected RequestBuilder(ElasticsearchClient client, PutDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java deleted file mode 100644 index 0e2a4eb15eb04..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/RunAnalyticsAction.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.core.ml.action; - -import org.elasticsearch.action.Action; -import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.ActionRequestBuilder; -import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.support.master.AcknowledgedResponse; -import org.elasticsearch.client.ElasticsearchClient; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.ToXContentObject; -import org.elasticsearch.common.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Objects; - -public class RunAnalyticsAction extends Action { - - public static final RunAnalyticsAction INSTANCE = new RunAnalyticsAction(); - public static final String NAME = "cluster:admin/xpack/ml/analytics/run"; - - private RunAnalyticsAction() { - super(NAME); - } - - @Override - public AcknowledgedResponse newResponse() { - return new AcknowledgedResponse(); - } - - public static class Request extends ActionRequest implements ToXContentObject { - - private String index; - - public Request(String index) { - this.index = index; - } - - public Request(StreamInput in) throws IOException { - readFrom(in); - } - - public Request() { - } - - public String getIndex() { - return index; - } - - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void readFrom(StreamInput in) throws IOException { - super.readFrom(in); - index = in.readString(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeString(index); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("index", index); - return builder; - } - - @Override - public int hashCode() { - return Objects.hash(index); - } - - 
@Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - if (obj == null || obj.getClass() != getClass()) { - return false; - } - RunAnalyticsAction.Request other = (RunAnalyticsAction.Request) obj; - return Objects.equals(index, other.index); - } - - @Override - public String toString() { - return Strings.toString(this); - } - } - - static class RequestBuilder extends ActionRequestBuilder { - - RequestBuilder(ElasticsearchClient client, RunAnalyticsAction action) { - super(client, action, new Request()); - } - } - -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..2a3f7fbd008a2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.Version; +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class StartDataFrameAnalyticsAction extends Action { + + public static final StartDataFrameAnalyticsAction INSTANCE = new StartDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/start"; + + private StartDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public AcknowledgedResponse newResponse() { + return new AcknowledgedResponse(); + } + + public static class Request extends MasterNodeRequest implements ToXContentObject { + + private String id; + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public Request(StreamInput in) throws IOException { + readFrom(in); + } + + public Request() { + } + + public String getId() { + return id; + } + + @Override + public ActionRequestValidationException 
validate() { + return null; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + id = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (id != null) { + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + } + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + StartDataFrameAnalyticsAction.Request other = (StartDataFrameAnalyticsAction.Request) obj; + return Objects.equals(id, other.id); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client, StartDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } + + public static class TaskParams implements XPackPlugin.XPackPersistentTaskParams { + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0])); + + public static TaskParams fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private String id; + + public TaskParams(String id) { + this.id = Objects.requireNonNull(id); + } + + public TaskParams(StreamInput in) throws IOException { + this.id = in.readString(); + } + + public String getId() { + return id; + } + + @Override + public String getWriteableName() { + return MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + // TODO Update to first released version + return Version.CURRENT; + 
} + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java new file mode 100644 index 0000000000000..acdb21b44dadf --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ContextParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class DataFrameAnalysisConfig implements ToXContentObject, Writeable { + + public static ContextParser parser() { + return (p, c) -> new DataFrameAnalysisConfig(p.mapOrdered()); + } + + private final Map config; + + public DataFrameAnalysisConfig(Map config) { + this.config = Objects.requireNonNull(config); + if (config.size() != 1) { + throw ExceptionsHelper.badRequestException("A data frame analysis must specify exactly one analysis type"); + } + } + + public DataFrameAnalysisConfig(StreamInput in) throws IOException { + config = in.readMap(); + } + + public Map asMap() { + return config; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(config); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.map(config); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataFrameAnalysisConfig that = (DataFrameAnalysisConfig) o; + return Objects.equals(config, that.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java new file mode 100644 
index 0000000000000..028e4347ba8f0 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -0,0 +1,163 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { + + public static final String TYPE = "data_frame_analytics_config"; + + public static final ParseField ID = new ParseField("id"); + public static final ParseField SOURCE = new ParseField("source"); + public static final ParseField DEST = new ParseField("dest"); + public static final ParseField ANALYSES = new ParseField("analyses"); + public static final ParseField CONFIG_TYPE = new ParseField("config_type"); + + public static final ObjectParser STRICT_PARSER = createParser(false); + public static final ObjectParser LENIENT_PARSER = createParser(true); + + public static ObjectParser createParser(boolean ignoreUnknownFields) { + ObjectParser parser = new ObjectParser<>(TYPE, ignoreUnknownFields, Builder::new); + + parser.declareString((c, s) -> {}, CONFIG_TYPE); + 
parser.declareString(Builder::setId, ID); + parser.declareString(Builder::setSource, SOURCE); + parser.declareString(Builder::setDest, DEST); + parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES); + return parser; + } + + private final String id; + private final String source; + private final String dest; + private final List analyses; + + public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses) { + this.id = ExceptionsHelper.requireNonNull(id, ID); + this.source = ExceptionsHelper.requireNonNull(source, SOURCE); + this.dest = ExceptionsHelper.requireNonNull(dest, DEST); + this.analyses = ExceptionsHelper.requireNonNull(analyses, ANALYSES); + if (analyses.isEmpty()) { + throw new ElasticsearchParseException("One or more analyses are required"); + } + // TODO Add support for multiple analyses + if (analyses.size() > 1) { + throw new UnsupportedOperationException("Does not yet support multiple analyses"); + } + } + + public DataFrameAnalyticsConfig(StreamInput in) throws IOException { + id = in.readString(); + source = in.readString(); + dest = in.readString(); + analyses = in.readList(DataFrameAnalysisConfig::new); + } + + public String getId() { + return id; + } + + public String getSource() { + return source; + } + + public String getDest() { + return dest; + } + + public List getAnalyses() { + return analyses; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ID.getPreferredName(), id); + builder.field(SOURCE.getPreferredName(), source); + builder.field(DEST.getPreferredName(), dest); + builder.field(ANALYSES.getPreferredName(), analyses); + if (params.paramAsBoolean(ToXContentParams.INCLUDE_TYPE, false)) { + builder.field(CONFIG_TYPE.getPreferredName(), TYPE); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + 
out.writeString(id); + out.writeString(source); + out.writeString(dest); + out.writeList(analyses); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsConfig other = (DataFrameAnalyticsConfig) o; + return Objects.equals(id, other.id) + && Objects.equals(source, other.source) + && Objects.equals(dest, other.dest) + && Objects.equals(analyses, other.analyses); + } + + @Override + public int hashCode() { + return Objects.hash(id, source, dest, analyses); + } + + public static String documentId(String id) { + return TYPE + "-" + id; + } + + public static class Builder { + + private String id; + private String source; + private String dest; + private List analyses; + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, ID); + } + + public void setSource(String source) { + this.source = ExceptionsHelper.requireNonNull(source, SOURCE); + } + + public void setDest(String dest) { + this.dest = ExceptionsHelper.requireNonNull(dest, DEST); + } + + public void setAnalyses(List analyses) { + this.analyses = ExceptionsHelper.requireNonNull(analyses, ANALYSES); + } + + public DataFrameAnalyticsConfig build() { + return new DataFrameAnalyticsConfig(id, source, dest, analyses); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java new file mode 100644 index 0000000000000..d40df259eec57 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum DataFrameAnalyticsState implements Writeable { + + STARTED, REINDEXING, ANALYZING, STOPPING, STOPPED; + + public static DataFrameAnalyticsState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static DataFrameAnalyticsState fromStream(StreamInput in) throws IOException { + return in.readEnum(DataFrameAnalyticsState.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java new file mode 100644 index 0000000000000..5d9b7ba756190 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.xpack.core.ml.MlTasks; + +import java.io.IOException; +import java.util.Objects; + +public class DataFrameAnalyticsTaskState implements PersistentTaskState { + + private static ParseField STATE = new ParseField("state"); + private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); + + private final DataFrameAnalyticsState state; + private final long allocationId; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, + a -> new DataFrameAnalyticsTaskState((DataFrameAnalyticsState) a[0], (long) a[1])); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return DataFrameAnalyticsState.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, STATE, ObjectParser.ValueType.STRING); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), ALLOCATION_ID); + } + + public static DataFrameAnalyticsTaskState fromXContent(XContentParser parser) { + try { + return PARSER.parse(parser, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public DataFrameAnalyticsTaskState(DataFrameAnalyticsState state, long allocationId) { + this.state = Objects.requireNonNull(state); + this.allocationId = allocationId; + } + + public DataFrameAnalyticsState 
getState() { + return state; + } + + public boolean isStatusStale(PersistentTasksCustomMetaData.PersistentTask task) { + return allocationId != task.getAllocationId(); + } + + @Override + public String getWriteableName() { + return MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + state.writeTo(out); + out.writeLong(allocationId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(STATE.getPreferredName(), state.toString()); + builder.field(ALLOCATION_ID.getPreferredName(), allocationId); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataFrameAnalyticsTaskState that = (DataFrameAnalyticsTaskState) o; + return allocationId == that.allocationId && + state == that.state; + } + + @Override + public int hashCode() { + return Objects.hash(state, allocationId); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index d51a8f10e4a5a..6f60a549ae707 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import 
org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -144,6 +145,7 @@ public static XContentBuilder configMapping() throws IOException { addJobConfigFields(builder); addDatafeedConfigFields(builder); + addDataFrameAnalyticsFields(builder); builder.endObject() .endObject() @@ -386,6 +388,32 @@ public static void addDatafeedConfigFields(XContentBuilder builder) throws IOExc .endObject(); } + public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws IOException { + builder.startObject(DataFrameAnalyticsConfig.ID.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsConfig.SOURCE.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsConfig.DEST.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsConfig.ANALYSES.getPreferredName()) + .startObject(PROPERTIES) + .startObject("outlier_detection") + .startObject(PROPERTIES) + .startObject("number_neighbours") + .field(TYPE, INTEGER) + .endObject() + .startObject("method") + .field(TYPE, KEYWORD) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + } + /** * Creates a default mapping which has a dynamic template that * treats all dynamically added fields as keywords. 
This is needed diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 333b87b0c294f..62e9f78e1c826 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -8,6 +8,7 @@ import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -256,6 +257,14 @@ public final class ReservedFieldNames { ChunkingConfig.MODE_FIELD.getPreferredName(), ChunkingConfig.TIME_SPAN_FIELD.getPreferredName(), + DataFrameAnalyticsConfig.ID.getPreferredName(), + DataFrameAnalyticsConfig.SOURCE.getPreferredName(), + DataFrameAnalyticsConfig.DEST.getPreferredName(), + DataFrameAnalyticsConfig.ANALYSES.getPreferredName(), + "outlier_detection", + "method", + "number_neighbours", + ElasticsearchMappings.CONFIG_TYPE }; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java index b66fd948a5a83..2d4c636172eca 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java @@ -10,7 +10,7 @@ /** * Interface for classes that write arrays of strings to the - * Ml 
analytics processes. + * Ml data frame analytics processes. */ public interface RecordWriter { /** diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index 47c0d4f64f96f..320eace983590 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -10,6 +10,7 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.common.ParseField; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -34,6 +35,14 @@ public static ResourceAlreadyExistsException datafeedAlreadyExists(String datafe return new ResourceAlreadyExistsException(Messages.getMessage(Messages.DATAFEED_ID_ALREADY_TAKEN, datafeedId)); } + public static ResourceNotFoundException missingDataFrameAnalytics(String id) { + return new ResourceNotFoundException("No known data frame analytics with id [{}]", id); + } + + public static ResourceAlreadyExistsException dataFrameAnalyticsAlreadyExists(String id) { + return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id); + } + public static ElasticsearchException serverError(String msg) { return new ElasticsearchException(msg); } @@ -86,4 +95,8 @@ public static T requireNonNull(T obj, String paramName) { } return obj; } + + public static T requireNonNull(T obj, ParseField paramName) { + return requireNonNull(obj, paramName.getPreferredName()); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java new file mode 100644 index 0000000000000..e3b7262095abb --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.test.AbstractStreamableTestCase; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Response; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; + +import java.util.ArrayList; +import java.util.List; + +public class GetDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { + + @Override + protected Response createTestInstance() { + int listSize = randomInt(10); + List analytics = new ArrayList<>(listSize); + for (int j = 0; j < listSize; j++) { + analytics.add(DataFrameAnalyticsConfigTests.createRandom(DataFrameAnalyticsConfigTests.randomValidId())); + } + return new Response(new QueryPage<>(analytics, analytics.size(), Response.RESULTS_FIELD)); + } + + @Override + protected Response createBlankInstance() { + return new Response(); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..48381526b8c4c --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Request; +import org.elasticsearch.xpack.core.ml.action.util.PageParams; + +public class GetDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(); + request.setResourceId(randomAlphaOfLength(20)); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java new file mode 100644 index 0000000000000..65c51faa9157e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; + +import java.util.ArrayList; +import java.util.List; + +public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + int listSize = randomInt(10); + List analytics = new ArrayList<>(listSize); + for (int j = 0; j < listSize; j++) { + Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), + randomFrom(DataFrameAnalyticsState.values()), null, randomAlphaOfLength(20)); + analytics.add(stats); + } + return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java new file mode 100644 index 0000000000000..8db7d8db7877c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Request; +import org.elasticsearch.xpack.core.ml.action.util.PageParams; + +public class GetDataFrameAnalyticsStatsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java new file mode 100644 index 0000000000000..2d899b7fb2d44 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractStreamableXContentTestCase; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Request; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.junit.Before; + +public class PutDataFrameAnalyticsActionRequestTests extends AbstractStreamableXContentTestCase { + + private String id; + + @Before + public void setUpId() { + id = DataFrameAnalyticsConfigTests.randomValidId(); + } + + @Override + protected Request createTestInstance() { + return new Request(DataFrameAnalyticsConfigTests.createRandom(id)); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected Request createBlankInstance() { + return new Request(); + } + + @Override + protected Request doParseInstance(XContentParser parser) { + return Request.parseRequest(id, parser); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java new file mode 100644 index 0000000000000..011044fb96eef --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.test.AbstractStreamableTestCase; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; + +public class PutDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { + + @Override + protected Response createTestInstance() { + return new Response(DataFrameAnalyticsConfigTests.createRandom(DataFrameAnalyticsConfigTests.randomValidId())); + } + + @Override + protected Response createBlankInstance() { + return new Response(); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..9875c87f5ef3c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.Request; + +public class StartDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return new Request(randomAlphaOfLength(20)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java new file mode 100644 index 0000000000000..a5dc889eea3db --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class DataFrameAnalysisConfigTests extends AbstractSerializingTestCase { + + @Override + protected DataFrameAnalysisConfig createTestInstance() { + return randomConfig(); + } + + @Override + protected DataFrameAnalysisConfig doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalysisConfig.parser().parse(parser, null); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataFrameAnalysisConfig::new; + } + + public static DataFrameAnalysisConfig randomConfig() { + Map configParams = new HashMap<>(); + int count = randomIntBetween(1, 5); + for (int i = 0; i < count; i++) { + if (randomBoolean()) { + configParams.put(randomAlphaOfLength(10), randomInt()); + } else { + configParams.put(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + } + Map config = new HashMap<>(); + config.put(randomAlphaOfLength(10), configParams); + return new DataFrameAnalysisConfig(config); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java new file mode 100644 index 0000000000000..dfb3e8ef36546 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { + + @Override + protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build(); + } + + @Override + protected DataFrameAnalyticsConfig createTestInstance() { + return createRandom(randomValidId()); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataFrameAnalyticsConfig::new; + } + + public static DataFrameAnalyticsConfig createRandom(String id) { + String source = randomAlphaOfLength(10); + String dest = randomAlphaOfLength(10); + List analyses = Collections.singletonList(DataFrameAnalysisConfigTests.randomConfig()); + return new DataFrameAnalyticsConfig(id, source, dest, analyses); + } + + public static String randomValidId() { + CodepointSetGenerator generator = new CodepointSetGenerator("abcdefghijklmnopqrstuvwxyz".toCharArray()); + return generator.ofCodePointsLength(random(), 10, 10); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java index 46d7c5b9e43da..00edb6bb6df2c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java @@ -29,6 +29,7 @@ public MlRestTestStateCleaner(Logger logger, RestClient adminClient) { 
public void clearMlMetadata() throws IOException { deleteAllDatafeeds(); deleteAllJobs(); + deleteAllDataFrameAnalytics(); // indices will be deleted by the ESRestTestCase class } @@ -91,4 +92,20 @@ private void deleteAllJobs() throws IOException { adminClient.performRequest(new Request("DELETE", "/_ml/anomaly_detectors/" + jobId)); } } + + private void deleteAllDataFrameAnalytics() throws IOException { + final Request analyticsRequest = new Request("GET", "/_ml/data_frame/analytics?size=10000"); + analyticsRequest.addParameter("filter_path", "data_frame_analytics"); + final Response analyticsResponse = adminClient.performRequest(analyticsRequest); + List> analytics = (List>) XContentMapValues.extractValue( + "data_frame_analytics", ESRestTestCase.entityAsMap(analyticsResponse)); + if (analytics == null) { + return; + } + + for (Map config : analytics) { + String id = (String) config.get("id"); + adminClient.performRequest(new Request("DELETE", "/_ml/data_frame/analytics/" + id)); + } + } } diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 43421b4591f0a..2ddcd3f276027 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -35,6 +35,17 @@ integTestRunner { 'ml/datafeeds_crud/Test put datafeed with invalid query', 'ml/datafeeds_crud/Test put datafeed with security headers in the body', 'ml/datafeeds_crud/Test update datafeed with missing id', + 'ml/data_frame_analytics_crud/Test put config with inconsistent body/param ids', + 'ml/data_frame_analytics_crud/Test put config with invalid id', + 'ml/data_frame_analytics_crud/Test put config with unknown top level field', + 'ml/data_frame_analytics_crud/Test put config with unknown field in outlier detection analysis', + 'ml/data_frame_analytics_crud/Test put config given missing source', + 'ml/data_frame_analytics_crud/Test put config given missing dest', + 
'ml/data_frame_analytics_crud/Test put config given missing analyses', + 'ml/data_frame_analytics_crud/Test put config given empty analyses', + 'ml/data_frame_analytics_crud/Test put config given two analyses', + 'ml/data_frame_analytics_crud/Test get given missing analytics', + 'ml/data_frame_analytics_crud/Test delete given missing config', 'ml/delete_job_force/Test cannot force delete a non-existent job', 'ml/delete_model_snapshot/Test delete snapshot missing snapshotId', 'ml/delete_model_snapshot/Test delete snapshot missing job_id', @@ -84,6 +95,9 @@ integTestRunner { 'ml/post_data/Test POST data with invalid parameters', 'ml/preview_datafeed/Test preview missing datafeed', 'ml/revert_model_snapshot/Test revert model with invalid snapshotId', + 'ml/start_data_frame_analytics/Test start given missing config', + 'ml/start_data_frame_analytics/Test start given missing source index', + 'ml/start_data_frame_analytics/Test start given source index has no compatible fields', 'ml/start_stop_datafeed/Test start datafeed job, but not open', 'ml/start_stop_datafeed/Test start non existing datafeed', 'ml/start_stop_datafeed/Test stop non existing datafeed', diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 272415421426f..bf9893e7f74d4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -61,6 +61,7 @@ import org.elasticsearch.xpack.core.ml.action.CloseJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction; import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction; import 
org.elasticsearch.xpack.core.ml.action.DeleteFilterAction; @@ -75,6 +76,8 @@ import org.elasticsearch.xpack.core.ml.action.GetCalendarEventsAction; import org.elasticsearch.xpack.core.ml.action.GetCalendarsAction; import org.elasticsearch.xpack.core.ml.action.GetCategoriesAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetFiltersAction; @@ -93,12 +96,13 @@ import org.elasticsearch.xpack.core.ml.action.PostDataAction; import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutCalendarAction; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; -import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; @@ -118,6 +122,7 @@ import org.elasticsearch.xpack.ml.action.TransportCloseJobAction; import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction; import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarEventAction; +import org.elasticsearch.xpack.ml.action.TransportDeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportDeleteDatafeedAction; 
import org.elasticsearch.xpack.ml.action.TransportDeleteExpiredDataAction; import org.elasticsearch.xpack.ml.action.TransportDeleteFilterAction; @@ -132,6 +137,8 @@ import org.elasticsearch.xpack.ml.action.TransportGetCalendarEventsAction; import org.elasticsearch.xpack.ml.action.TransportGetCalendarsAction; import org.elasticsearch.xpack.ml.action.TransportGetCategoriesAction; +import org.elasticsearch.xpack.ml.action.TransportGetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.ml.action.TransportGetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.ml.action.TransportGetDatafeedsAction; import org.elasticsearch.xpack.ml.action.TransportGetDatafeedsStatsAction; import org.elasticsearch.xpack.ml.action.TransportGetFiltersAction; @@ -150,12 +157,13 @@ import org.elasticsearch.xpack.ml.action.TransportPostDataAction; import org.elasticsearch.xpack.ml.action.TransportPreviewDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportPutCalendarAction; +import org.elasticsearch.xpack.ml.action.TransportPutDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportPutDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportPutFilterAction; import org.elasticsearch.xpack.ml.action.TransportPutJobAction; import org.elasticsearch.xpack.ml.action.TransportRevertModelSnapshotAction; -import org.elasticsearch.xpack.ml.action.TransportRunAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportStopDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportUpdateCalendarJobAction; @@ -166,12 +174,14 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import 
org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; -import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory; import org.elasticsearch.xpack.ml.datafeed.DatafeedJobBuilder; import org.elasticsearch.xpack.ml.datafeed.DatafeedManager; import org.elasticsearch.xpack.ml.datafeed.persistence.DatafeedConfigProvider; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; +import org.elasticsearch.xpack.ml.dataframe.process.NativeAnalyticsProcessFactory; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -197,7 +207,6 @@ import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction; import org.elasticsearch.xpack.ml.rest.RestFindFileStructureAction; import org.elasticsearch.xpack.ml.rest.RestMlInfoAction; -import org.elasticsearch.xpack.ml.rest.analytics.RestRunAnalyticsAction; import org.elasticsearch.xpack.ml.rest.RestSetUpgradeModeAction; import org.elasticsearch.xpack.ml.rest.calendar.RestDeleteCalendarAction; import org.elasticsearch.xpack.ml.rest.calendar.RestDeleteCalendarEventAction; @@ -215,6 +224,11 @@ import org.elasticsearch.xpack.ml.rest.datafeeds.RestStartDatafeedAction; import org.elasticsearch.xpack.ml.rest.datafeeds.RestStopDatafeedAction; import org.elasticsearch.xpack.ml.rest.datafeeds.RestUpdateDatafeedAction; +import org.elasticsearch.xpack.ml.rest.dataframe.RestDeleteDataFrameAnalyticsAction; +import org.elasticsearch.xpack.ml.rest.dataframe.RestGetDataFrameAnalyticsAction; 
+import org.elasticsearch.xpack.ml.rest.dataframe.RestGetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.ml.rest.dataframe.RestPutDataFrameAnalyticsAction; +import org.elasticsearch.xpack.ml.rest.dataframe.RestStartDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.rest.filter.RestDeleteFilterAction; import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction; import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction; @@ -291,6 +305,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu private final SetOnce autodetectProcessManager = new SetOnce<>(); private final SetOnce datafeedManager = new SetOnce<>(); + private final SetOnce dataFrameAnalyticsManager = new SetOnce<>(); private final SetOnce memoryTracker = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { @@ -455,8 +470,13 @@ public Collection createComponents(Client client, ClusterService cluster // run node startup tasks autodetectProcessManager.onNodeStartup(); + // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, environment, threadPool, analyticsProcessFactory); + DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client); + DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(clusterService, client, + dataFrameAnalyticsConfigProvider, analyticsProcessManager); + this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager); return Arrays.asList( mlLifeCycleService, @@ -472,7 +492,8 @@ public Collection createComponents(Client client, ClusterService cluster auditor, new MlAssignmentNotifier(settings, auditor, threadPool, client, clusterService), memoryTracker, - analyticsProcessManager + analyticsProcessManager, + dataFrameAnalyticsConfigProvider ); } @@ -487,7 +508,8 @@ public List> getPersistentTasksExecutor(ClusterServic return Arrays.asList( new 
TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get(), memoryTracker.get(), client), - new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(datafeedManager.get()) + new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(datafeedManager.get()), + new TransportStartDataFrameAnalyticsAction.TaskExecutor(dataFrameAnalyticsManager.get()) ); } @@ -559,8 +581,12 @@ public List getRestHandlers(Settings settings, RestController restC new RestGetCalendarEventsAction(settings, restController), new RestPostCalendarEventAction(settings, restController), new RestFindFileStructureAction(settings, restController), - new RestRunAnalyticsAction(settings, restController), - new RestSetUpgradeModeAction(settings, restController) + new RestSetUpgradeModeAction(settings, restController), + new RestGetDataFrameAnalyticsAction(settings, restController), + new RestGetDataFrameAnalyticsStatsAction(settings, restController), + new RestPutDataFrameAnalyticsAction(settings, restController), + new RestDeleteDataFrameAnalyticsAction(settings, restController), + new RestStartDataFrameAnalyticsAction(settings, restController) ); } @@ -619,8 +645,12 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(PostCalendarEventsAction.INSTANCE, TransportPostCalendarEventsAction.class), new ActionHandler<>(PersistJobAction.INSTANCE, TransportPersistJobAction.class), new ActionHandler<>(FindFileStructureAction.INSTANCE, TransportFindFileStructureAction.class), - new ActionHandler<>(RunAnalyticsAction.INSTANCE, TransportRunAnalyticsAction.class), - new ActionHandler<>(SetUpgradeModeAction.INSTANCE, TransportSetUpgradeModeAction.class) + new ActionHandler<>(SetUpgradeModeAction.INSTANCE, TransportSetUpgradeModeAction.class), + new ActionHandler<>(GetDataFrameAnalyticsAction.INSTANCE, TransportGetDataFrameAnalyticsAction.class), + new 
ActionHandler<>(GetDataFrameAnalyticsStatsAction.INSTANCE, TransportGetDataFrameAnalyticsStatsAction.class), + new ActionHandler<>(PutDataFrameAnalyticsAction.INSTANCE, TransportPutDataFrameAnalyticsAction.class), + new ActionHandler<>(DeleteDataFrameAnalyticsAction.INSTANCE, TransportDeleteDataFrameAnalyticsAction.class), + new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class) ); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java new file mode 100644 index 0000000000000..56fa331230f03 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.ml.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.ml.utils.MlIndicesUtils; + +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + 
+public abstract class AbstractTransportGetResourcesAction> + extends HandledTransportAction { + + private static final String ALL = "_all"; + + private Client client; + + protected AbstractTransportGetResourcesAction(String actionName, TransportService transportService, ActionFilters actionFilters, + Supplier request, Client client) { + super(actionName, transportService, actionFilters, request); + this.client = Objects.requireNonNull(client); + } + + protected void searchResources(AbstractGetResourcesRequest request, ActionListener> listener) { + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .sort(request.getResourceIdField()) + .from(request.getPageParams().getFrom()) + .size(request.getPageParams().getSize()) + .query(buildQuery(request)); + + SearchRequest searchRequest = new SearchRequest(getIndices()) + .indicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS)) + .source(sourceBuilder); + + executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest, new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + List docs = new ArrayList<>(); + for (SearchHit hit : response.getHits().getHits()) { + BytesReference docSource = hit.getSourceRef(); + try (InputStream stream = docSource.streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser( + NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) { + docs.add(parse(parser)); + } catch (IOException e) { + this.onFailure(e); + } + } + + if (docs.isEmpty() && isConcreteMatch(request.getResourceId())) { + listener.onFailure(notFoundException(request.getResourceId())); + } else { + listener.onResponse(new QueryPage<>(docs, docs.size(), getResultsField())); + } + } + + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }, + client::search); + } + + private QueryBuilder buildQuery(AbstractGetResourcesRequest request) { + 
BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + if (isMatchAll(request.getResourceId()) == false) { + boolQuery.filter(QueryBuilders.wildcardQuery(request.getResourceIdField(), request.getResourceId())); + } + QueryBuilder additionalQuery = additionalQuery(); + if (additionalQuery != null) { + boolQuery.filter(additionalQuery); + } + return boolQuery.hasClauses() ? boolQuery : QueryBuilders.matchAllQuery(); + } + + private static boolean isMatchAll(String resourceId) { + return Strings.isNullOrEmpty(resourceId) || ALL.equals(resourceId) || Regex.isMatchAllPattern(resourceId); + } + + private static boolean isConcreteMatch(String resourceId) { + return isMatchAll(resourceId) == false && Regex.isSimpleMatchPattern(resourceId) == false; + } + + @Nullable + protected QueryBuilder additionalQuery() { + return null; + } + + protected abstract ParseField getResultsField(); + + protected abstract String[] getIndices(); + + protected abstract Resource parse(XContentParser parser); + + protected abstract ResourceNotFoundException notFoundException(String resourceId); +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..c14f8bef92c0e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.delete.DeleteAction; +import org.elasticsearch.action.delete.DeleteRequest; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +/** + * The action is a master node action to ensure it reads an up-to-date cluster + * state in order to determine whether there is a persistent task for the analytics + * to delete. 
+ */ +public class TransportDeleteDataFrameAnalyticsAction + extends TransportMasterNodeAction { + + private final Client client; + + @Inject + public TransportDeleteDataFrameAnalyticsAction(TransportService transportService, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, Client client) { + super(DeleteDataFrameAnalyticsAction.NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, + DeleteDataFrameAnalyticsAction.Request::new); + this.client = client; + } + + @Override + protected String executor() { + return ThreadPool.Names.SAME; + } + + @Override + protected AcknowledgedResponse newResponse() { + return new AcknowledgedResponse(); + } + + @Override + protected void masterOperation(DeleteDataFrameAnalyticsAction.Request request, ClusterState state, + ActionListener listener) { + PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + DataFrameAnalyticsState taskState = MlTasks.getDataFrameAnalyticsState(request.getId(), tasks); + if (taskState != DataFrameAnalyticsState.STOPPED) { + listener.onFailure(ExceptionsHelper.conflictStatusException("Cannot delete data frame analytics [{}] while its status is [{}]", + request.getId(), taskState)); + return; + } + + DeleteRequest deleteRequest = new DeleteRequest(AnomalyDetectorsIndex.configIndexName()); + deleteRequest.id(DataFrameAnalyticsConfig.documentId(request.getId())); + deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + executeAsyncWithOrigin(client, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap( + deleteResponse -> { + if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { + listener.onFailure(ExceptionsHelper.missingDataFrameAnalytics(request.getId())); + return; + } + assert deleteResponse.getResult() == DocWriteResponse.Result.DELETED; + listener.onResponse(new 
AcknowledgedResponse(true)); + }, + listener::onFailure + )); + } + + @Override + protected ClusterBlockException checkBlock(DeleteDataFrameAnalyticsAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..02083b6c7d45e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +public class TransportGetDataFrameAnalyticsAction extends AbstractTransportGetResourcesAction { + + 
@Inject + public TransportGetDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters, Client client) { + super(GetDataFrameAnalyticsAction.NAME, transportService, actionFilters, GetDataFrameAnalyticsAction.Request::new, client); + } + + @Override + protected ParseField getResultsField() { + return GetDataFrameAnalyticsAction.Response.RESULTS_FIELD; + } + + @Override + protected String[] getIndices() { + return new String[] { AnomalyDetectorsIndex.configIndexName() }; + } + + @Override + protected DataFrameAnalyticsConfig parse(XContentParser parser) { + return DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + } + + @Override + protected ResourceNotFoundException notFoundException(String resourceId) { + return ExceptionsHelper.missingDataFrameAnalytics(resourceId); + } + + @Override + protected void doExecute(Task task, GetDataFrameAnalyticsAction.Request request, + ActionListener listener) { + searchResources(request, ActionListener.wrap( + queryPage -> listener.onResponse(new GetDataFrameAnalyticsAction.Response(queryPage)), + listener::onFailure + )); + } + + @Nullable + protected QueryBuilder additionalQuery() { + return QueryBuilders.termQuery(DataFrameAnalyticsConfig.CONFIG_TYPE.getPreferredName(), DataFrameAnalyticsConfig.TYPE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java new file mode 100644 index 0000000000000..1a30200668d93 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; + +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class TransportGetDataFrameAnalyticsStatsAction + extends TransportMasterNodeAction { + + private final Client client; + + @Inject + public TransportGetDataFrameAnalyticsStatsAction(TransportService transportService, ClusterService clusterService, Client client, + ThreadPool threadPool, ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver) { + super(GetDataFrameAnalyticsStatsAction.NAME, transportService, clusterService, threadPool, 
actionFilters, + indexNameExpressionResolver, GetDataFrameAnalyticsStatsAction.Request::new); + this.client = client; + } + + @Override + protected String executor() { + return ThreadPool.Names.SAME; + } + + @Override + protected GetDataFrameAnalyticsStatsAction.Response newResponse() { + return new GetDataFrameAnalyticsStatsAction.Response(); + } + + @Override + protected void masterOperation(GetDataFrameAnalyticsStatsAction.Request request, ClusterState state, + ActionListener listener) throws Exception { + PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + + ActionListener getResponseListener = ActionListener.wrap( + response -> { + List stats = new ArrayList(response.getResources().results().size()); + response.getResources().results().forEach(c -> stats.add(buildStats(c.getId(), tasks, state))); + listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>(stats, stats.size(), + GetDataFrameAnalyticsAction.Response.RESULTS_FIELD))); + }, + listener::onFailure + ); + + GetDataFrameAnalyticsAction.Request getRequest = new GetDataFrameAnalyticsAction.Request(); + getRequest.setResourceId(request.getId()); + getRequest.setPageParams(request.getPageParams()); + executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, getRequest, getResponseListener); + } + + private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, PersistentTasksCustomMetaData tasks, + ClusterState clusterState) { + PersistentTasksCustomMetaData.PersistentTask analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks); + DataFrameAnalyticsState analyticsState = MlTasks.getDataFrameAnalyticsState(concreteAnalyticsId, tasks); + DiscoveryNode node = null; + String assignmentExplanation = null; + if (analyticsTask != null) { + node = clusterState.nodes().get(analyticsTask.getExecutorNode()); + assignmentExplanation = 
analyticsTask.getAssignment().getExplanation(); + } + return new GetDataFrameAnalyticsStatsAction.Response.Stats( + concreteAnalyticsId, analyticsState, node, assignmentExplanation); + } + + @Override + protected ClusterBlockException checkBlock(GetDataFrameAnalyticsStatsAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..e96ca02ce4f89 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; + +import java.util.function.Supplier; + +public class TransportPutDataFrameAnalyticsAction + extends HandledTransportAction { + + private final XPackLicenseState licenseState; + private final DataFrameAnalyticsConfigProvider configProvider; + + @Inject + public TransportPutDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters, + XPackLicenseState licenseState, DataFrameAnalyticsConfigProvider configProvider) { + super(PutDataFrameAnalyticsAction.NAME, transportService, actionFilters, + (Supplier) PutDataFrameAnalyticsAction.Request::new); + this.licenseState = licenseState; + this.configProvider = configProvider; + } + + @Override + protected void doExecute(Task task, PutDataFrameAnalyticsAction.Request request, + ActionListener listener) { + if (licenseState.isMachineLearningAllowed() == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + + 
validateConfig(request.getConfig()); + configProvider.put(request.getConfig(), ActionListener.wrap( + indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())), + listener::onFailure + )); + } + + private void validateConfig(DataFrameAnalyticsConfig config) { + if (MlStrings.isValidId(config.getId()) == false) { + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.INVALID_ID, DataFrameAnalyticsConfig.ID, + config.getId())); + } + if (!MlStrings.hasValidLengthForId(config.getId())) { + throw ExceptionsHelper.badRequestException("id [{}] is too long; must not contain more than {} characters", config.getId(), + MlStrings.ID_LENGTH_LIMIT); + } + if (config.getSource().isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must be non-empty", DataFrameAnalyticsConfig.SOURCE); + } + if (config.getDest().isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must be non-empty", DataFrameAnalyticsConfig.DEST); + } + DataFrameAnalysesUtils.readAnalyses(config.getAnalyses()); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..e45b16a6c6786 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -0,0 +1,192 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.persistent.PersistentTasksExecutor; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; +import 
org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.function.Predicate; + +/** + * Starts the persistent task for running data frame analytics. + * + * TODO Add to the upgrade mode action + */ +public class TransportStartDataFrameAnalyticsAction + extends TransportMasterNodeAction { + + private final XPackLicenseState licenseState; + private final Client client; + private final PersistentTasksService persistentTasksService; + private final DataFrameAnalyticsConfigProvider configProvider; + + @Inject + public TransportStartDataFrameAnalyticsAction(TransportService transportService, Client client, ClusterService clusterService, + ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, + IndexNameExpressionResolver indexNameExpressionResolver, + PersistentTasksService persistentTasksService, + DataFrameAnalyticsConfigProvider configProvider) { + super(StartDataFrameAnalyticsAction.NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, + StartDataFrameAnalyticsAction.Request::new); + this.licenseState = licenseState; + this.client = client; + this.persistentTasksService = persistentTasksService; + this.configProvider = configProvider; + } + + @Override + protected String executor() { + // This api doesn't do heavy or blocking operations (just delegates PersistentTasksService), + // so we can do this on the network thread + return ThreadPool.Names.SAME; + } + + @Override + protected AcknowledgedResponse newResponse() { + return new AcknowledgedResponse(); + } + + @Override + protected ClusterBlockException checkBlock(StartDataFrameAnalyticsAction.Request request, ClusterState state) { + // We only delegate here to PersistentTasksService, but if there is a metadata writeblock, + // then 
delegating to PersistentTasksService doesn't make a whole lot of sense, + // because PersistentTasksService will then fail. + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } + + @Override + protected void masterOperation(StartDataFrameAnalyticsAction.Request request, ClusterState state, + ActionListener listener) { + if (licenseState.isMachineLearningAllowed() == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + + StartDataFrameAnalyticsAction.TaskParams taskParams = new StartDataFrameAnalyticsAction.TaskParams(request.getId()); + + // Wait for analytics to be started + ActionListener> waitForAnalyticsToStart = + new ActionListener>() { + @Override + public void onResponse(PersistentTasksCustomMetaData.PersistentTask task) { + listener.onResponse(new AcknowledgedResponse(true)); + } + + @Override + public void onFailure(Exception e) { + if (e instanceof ResourceAlreadyExistsException) { + e = new ElasticsearchStatusException("Cannot open data frame analytics [" + request.getId() + + "] because it has already been opened", RestStatus.CONFLICT, e); + } + listener.onFailure(e); + } + }; + + // Start persistent task + ActionListener validatedConfigListener = ActionListener.wrap( + config -> persistentTasksService.sendStartRequest(MlTasks.dataFrameAnalyticsTaskId(request.getId()), + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, taskParams, waitForAnalyticsToStart), + listener::onFailure + ); + + // Validate config + ActionListener configListener = ActionListener.wrap( + config -> DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), config.getId(), config.getSource(), + ActionListener.wrap(dataExtractorFactory -> validatedConfigListener.onResponse(config), listener::onFailure)), + listener::onFailure + ); + + // Get config + configProvider.get(request.getId(), configListener); + } + + public static class DataFrameAnalyticsTask extends AllocatedPersistentTask { + + 
private final StartDataFrameAnalyticsAction.TaskParams taskParams; + + public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map headers, + StartDataFrameAnalyticsAction.TaskParams taskParams) { + super(id, type, action, "data_frame_analytics-" + taskParams.getId(), parentTask, headers); + this.taskParams = Objects.requireNonNull(taskParams); + } + + public StartDataFrameAnalyticsAction.TaskParams getParams() { + return taskParams; + } + } + + public static class TaskExecutor extends PersistentTasksExecutor { + + private final DataFrameAnalyticsManager manager; + + public TaskExecutor(DataFrameAnalyticsManager manager) { + super(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, MachineLearning.UTILITY_THREAD_POOL_NAME); + this.manager = Objects.requireNonNull(manager); + } + + @Override + protected AllocatedPersistentTask createTask( + long id, String type, String action, TaskId parentTaskId, + PersistentTasksCustomMetaData.PersistentTask persistentTask, + Map headers) { + return new DataFrameAnalyticsTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); + } + + @Override + protected DiscoveryNode selectLeastLoadedNode(ClusterState clusterState, Predicate selector) { + // For starters, let's just select the least loaded ML node + // TODO implement memory-based load balancing + return super.selectLeastLoadedNode(clusterState, MachineLearning::isMlNode); + } + + @Override + protected void nodeOperation(AllocatedPersistentTask task, StartDataFrameAnalyticsAction.TaskParams params, + PersistentTaskState state) { + DataFrameAnalyticsTaskState startedState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.STARTED, + task.getAllocationId()); + task.updatePersistentTaskState(startedState, ActionListener.wrap( + response -> manager.execute((DataFrameAnalyticsTask) task), + task::markAsFailed + )); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalysis.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalysis.java deleted file mode 100644 index 81062f9795040..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalysis.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.ml.dataframe; - -import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.xcontent.ToXContentObject; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; - -import java.io.IOException; - -public class DataFrameAnalysis implements ToXContentObject { - - private static final ParseField NAME = new ParseField("name"); - - private final String name; - - public DataFrameAnalysis(String name) { - this.name = ExceptionsHelper.requireNonNull(name, NAME.getPreferredName()); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(NAME.getPreferredName(), name); - builder.endObject(); - return builder; - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsFields.java similarity index 79% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameFields.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsFields.java index 164b7888a6ffe..eeb3a8badce39 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameFields.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsFields.java @@ -5,9 +5,9 @@ */ package org.elasticsearch.xpack.ml.dataframe; -public final class DataFrameFields { +public final class DataFrameAnalyticsFields { public static final String ID = "_id_copy"; - private DataFrameFields() {} + private DataFrameAnalyticsFields() {} } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java similarity index 50% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index cb9e25504a8e4..72b1319689820 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -3,28 +3,21 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.action; +package org.elasticsearch.xpack.ml.dataframe; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.admin.indices.create.CreateIndexAction; import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; -import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; import org.elasticsearch.client.Client; -import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.MappingMetaData; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.collect.ImmutableOpenMap; -import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.env.Environment; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexSortConfig; import org.elasticsearch.index.reindex.BulkByScrollResponse; @@ -32,14 +25,12 @@ import org.elasticsearch.index.reindex.ReindexRequest; import org.elasticsearch.script.Script; import org.elasticsearch.search.sort.SortOrder; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.ml.MachineLearning; 
-import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractorFactory; -import org.elasticsearch.xpack.ml.dataframe.DataFrameFields; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; import java.util.Arrays; @@ -47,16 +38,9 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Supplier; +import java.util.Objects; -public class TransportRunAnalyticsAction extends HandledTransportAction { - - private final TransportService transportService; - private final ThreadPool threadPool; - private final Client client; - private final ClusterService clusterService; - private final Environment environment; - private final AnalyticsProcessManager analyticsProcessManager; +public class DataFrameAnalyticsManager { /** * Unfortunately, getting the settings of an index include internal settings that should @@ -70,67 +54,96 @@ public class TransportRunAnalyticsAction extends HandledTransportAction) RunAnalyticsAction.Request::new); - this.transportService = transportService; - this.threadPool = threadPool; - this.client = client; - this.clusterService = clusterService; - this.environment = environment; - this.analyticsProcessManager = analyticsProcessManager; + private final ClusterService clusterService; + private final Client client; + private final DataFrameAnalyticsConfigProvider configProvider; + private final AnalyticsProcessManager processManager; + + public 
DataFrameAnalyticsManager(ClusterService clusterService, Client client, DataFrameAnalyticsConfigProvider configProvider, + AnalyticsProcessManager processManager) { + this.clusterService = Objects.requireNonNull(clusterService); + this.client = Objects.requireNonNull(client); + this.configProvider = Objects.requireNonNull(configProvider); + this.processManager = Objects.requireNonNull(processManager); } - @Override - protected void doExecute(Task task, RunAnalyticsAction.Request request, ActionListener listener) { - DiscoveryNode localNode = clusterService.localNode(); - if (MachineLearning.isMlNode(localNode)) { - reindexDataframeAndStartAnalysis(request.getIndex(), listener); - return; - } + public void execute(DataFrameAnalyticsTask task) { + ActionListener reindexingStateListener = ActionListener.wrap( + config -> reindexDataframeAndStartAnalysis(task, config), + e -> task.markAsFailed(e) + ); - ClusterState clusterState = clusterService.state(); - for (DiscoveryNode node : clusterState.getNodes()) { - if (MachineLearning.isMlNode(node)) { - transportService.sendRequest(node, actionName, request, - new ActionListenerResponseHandler<>(listener, inputStream -> { - AcknowledgedResponse response = new AcknowledgedResponse(); - response.readFrom(inputStream); - return response; - })); - return; - } - } - listener.onFailure(ExceptionsHelper.badRequestException("No ML node to run on")); + // Update task state to REINDEXING + ActionListener configListener = ActionListener.wrap( + config -> { + DataFrameAnalyticsTaskState reindexingState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.REINDEXING, + task.getAllocationId()); + task.updatePersistentTaskState(reindexingState, ActionListener.wrap( + updatedTask -> reindexingStateListener.onResponse(config), + reindexingStateListener::onFailure + )); + }, + reindexingStateListener::onFailure + ); + + // Retrieve configuration + configProvider.get(task.getParams().getId(), configListener); } - private void 
reindexDataframeAndStartAnalysis(String index, ActionListener listener) { - final String destinationIndex = index + "_copy"; + private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { + // Reindexing is complete; start analytics + ActionListener refreshListener = ActionListener.wrap( + refreshResponse -> startAnalytics(task, config), + task::markAsFailed + ); + // Refresh to ensure copied index is fully searchable ActionListener reindexCompletedListener = ActionListener.wrap( - bulkResponse -> { - client.execute(RefreshAction.INSTANCE, new RefreshRequest(destinationIndex), ActionListener.wrap( - refreshResponse -> { - runPipelineAnalytics(destinationIndex, listener); - }, listener::onFailure - )); - }, listener::onFailure + bulkResponse -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest()), refreshListener), + e -> task.markAsFailed(e) ); + // Reindex ActionListener copyIndexCreatedListener = ActionListener.wrap( createIndexResponse -> { ReindexRequest reindexRequest = new ReindexRequest(); - reindexRequest.setSourceIndices(index); - reindexRequest.setDestIndex(destinationIndex); - reindexRequest.setScript(new Script("ctx._source." + DataFrameFields.ID + " = ctx._id")); + reindexRequest.setSourceIndices(config.getSource()); + reindexRequest.setDestIndex(config.getDest()); + reindexRequest.setScript(new Script("ctx._source." 
+ DataFrameAnalyticsFields.ID + " = ctx._id")); client.execute(ReindexAction.INSTANCE, reindexRequest, reindexCompletedListener); - }, listener::onFailure + }, + reindexCompletedListener::onFailure ); - createDestinationIndex(index, destinationIndex, copyIndexCreatedListener); + createDestinationIndex(config.getSource(), config.getDest(), copyIndexCreatedListener); + } + + private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { + // Update state to ANALYZING and start process + ActionListener dataExtractorFactoryListener = ActionListener.wrap( + dataExtractorFactory -> { + DataFrameAnalyticsTaskState analyzingState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.ANALYZING, + task.getAllocationId()); + task.updatePersistentTaskState(analyzingState, ActionListener.wrap( + updatedTask -> processManager.runJob(config, dataExtractorFactory, + error -> { + if (error != null) { + task.markAsFailed(error); + } else { + task.markAsCompleted(); + } + }), + task::markAsFailed + )); + }, + e -> task.markAsFailed(e) + ); + + // TODO This could fail with errors. In that case we get stuck with the copied index. + // We could delete the index in case of failure or we could try building the factory before reindexing + // to catch the error early on. 
+ DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), config.getId(), config.getDest(), + dataExtractorFactoryListener); } private void createDestinationIndex(String sourceIndex, String destinationIndex, ActionListener listener) { @@ -140,14 +153,9 @@ private void createDestinationIndex(String sourceIndex, String destinationIndex, return; } - if (indexMetaData.getMappings().size() != 1) { - listener.onFailure(ExceptionsHelper.badRequestException("Does not support indices with multiple types")); - return; - } - Settings.Builder settingsBuilder = Settings.builder().put(indexMetaData.getSettings()); INTERNAL_SETTINGS.stream().forEach(settingsBuilder::remove); - settingsBuilder.put(IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), DataFrameFields.ID); + settingsBuilder.put(IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), DataFrameAnalyticsFields.ID); settingsBuilder.put(IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey(), SortOrder.ASC); CreateIndexRequest createIndexRequest = new CreateIndexRequest(destinationIndex, settingsBuilder.build()); @@ -161,25 +169,8 @@ private static void addDestinationIndexMappings(IndexMetaData indexMetaData, Cre Map properties = (Map) mappingsAsMap.get("properties"); Map idCopyMapping = new HashMap<>(); idCopyMapping.put("type", "keyword"); - properties.put(DataFrameFields.ID, idCopyMapping); + properties.put(DataFrameAnalyticsFields.ID, idCopyMapping); createIndexRequest.mapping(mappings.keysIt().next(), mappingsAsMap); } - - private void runPipelineAnalytics(String index, ActionListener listener) { - String jobId = "ml-analytics-" + index; - - ActionListener dataExtractorFactoryListener = ActionListener.wrap( - dataExtractorFactory -> { - analyticsProcessManager.runJob(jobId, dataExtractorFactory); - listener.onResponse(new AcknowledgedResponse(true)); - }, - listener::onFailure - ); - - // TODO This could fail with errors. In that case we get stuck with the copied index. 
- // We could delete the index in case of failure or we could try building the factory before reindexing - // to catch the error early on. - DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), index, dataExtractorFactoryListener); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java new file mode 100644 index 0000000000000..90bcc839bb361 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.analyses; + +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; + +public abstract class AbstractDataFrameAnalysis implements DataFrameAnalysis { + + private static final String NAME = "name"; + private static final String PARAMETERS = "parameters"; + + @Override + public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NAME, getType()); + builder.field(PARAMETERS, getParams()); + builder.endObject(); + return builder; + } + + protected abstract Map getParams(); +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java new file mode 100644 index 0000000000000..5151d0c0c6e8d --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public final class DataFrameAnalysesUtils { + + private static final Map factories; + + static { + factories = new HashMap<>(); + factories.put(DataFrameAnalysis.Type.OUTLIER_DETECTION, new OutlierDetection.Factory()); + } + + private DataFrameAnalysesUtils() {} + + public static List readAnalyses(List analyses) { + return analyses.stream().map(DataFrameAnalysesUtils::readAnalysis).collect(Collectors.toList()); + } + + static DataFrameAnalysis readAnalysis(DataFrameAnalysisConfig config) { + Map configMap = config.asMap(); + DataFrameAnalysis.Type analysisType = DataFrameAnalysis.Type.fromString(configMap.keySet().iterator().next()); + DataFrameAnalysis.Factory factory = factories.get(analysisType); + Map analysisConfig = castAsMapAndCopy(analysisType, configMap.get(analysisType.toString())); + DataFrameAnalysis dataFrameAnalysis = factory.create(analysisConfig); + if (analysisConfig.isEmpty() == false) { + throw new ElasticsearchParseException("Data frame analysis [{}] does not support one or more provided parameters {}", + analysisType, analysisConfig.keySet()); + } + return dataFrameAnalysis; + } + + private static Map castAsMapAndCopy(DataFrameAnalysis.Type analysisType, Object obj) { + try { + return new HashMap<>((Map) obj); + } catch (ClassCastException 
e) { + throw new ElasticsearchParseException("[{}] expected to be a map but was of type [{}]", analysisType, obj.getClass().getName()); + } + } + + @Nullable + static Integer readInt(DataFrameAnalysis.Type analysisType, Map config, String property) { + Object value = config.remove(property); + if (value == null) { + return null; + } + try { + return (int) value; + } catch (ClassCastException e) { + throw new ElasticsearchParseException("Property [{}] of analysis [{}] should be of type [Integer] but was [{}]", + property, analysisType, value.getClass().getSimpleName()); + } + } + + @Nullable + static String readString(DataFrameAnalysis.Type analysisType, Map config, String property) { + Object value = config.remove(property); + if (value == null) { + return null; + } + try { + return (String) value; + } catch (ClassCastException e) { + throw new ElasticsearchParseException("Property [{}] of analysis [{}] should be of type [String] but was [{}]", + property, analysisType, value.getClass().getSimpleName()); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java new file mode 100644 index 0000000000000..9fdd093fa324e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.xcontent.ToXContentObject; + +import java.util.Locale; +import java.util.Map; + +public interface DataFrameAnalysis extends ToXContentObject { + + enum Type { + OUTLIER_DETECTION; + + public static Type fromString(String value) { + try { + return Type.valueOf(value.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new ElasticsearchParseException("Unknown analysis type [{}]", value); + } + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } + + Type getType(); + + interface Factory { + + /** + * Creates a data frame analysis based on the specified map of maps config. + * + * @param config The configuration for the analysis + * + * Note: Implementations are responsible for removing the used configuration keys, so that after + * creation it can be verified that all configurations settings have been used. + */ + DataFrameAnalysis create(Map config); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java new file mode 100644 index 0000000000000..47f614ba658f6 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.analyses; + +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +public class OutlierDetection extends AbstractDataFrameAnalysis { + + public enum Method { + LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; + + public static Method fromString(String value) { + return Method.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } + + public static final String NUMBER_NEIGHBOURS = "number_neighbours"; + public static final String METHOD = "method"; + + private final Integer numberNeighbours; + private final Method method; + + public OutlierDetection(Integer numberNeighbours, Method method) { + this.numberNeighbours = numberNeighbours; + this.method = method; + } + + @Override + public Type getType() { + return Type.OUTLIER_DETECTION; + } + + @Override + protected Map getParams() { + Map params = new HashMap<>(); + if (numberNeighbours != null) { + params.put(NUMBER_NEIGHBOURS, numberNeighbours); + } + if (method != null) { + params.put(METHOD, method); + } + return params; + } + + static class Factory implements DataFrameAnalysis.Factory { + + @Override + public DataFrameAnalysis create(Map config) { + Integer numberNeighbours = DataFrameAnalysesUtils.readInt(Type.OUTLIER_DETECTION, config, NUMBER_NEIGHBOURS); + String method = DataFrameAnalysesUtils.readString(Type.OUTLIER_DETECTION, config, METHOD); + return new OutlierDetection(numberNeighbours, method == null ? 
null : Method.fromString(method)); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java similarity index 97% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractor.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 3d17ff7afd2c5..055a5e7d8dd64 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe; +package org.elasticsearch.xpack.ml.dataframe.extractor; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -22,6 +22,7 @@ import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsFields; import java.io.IOException; import java.util.ArrayList; @@ -82,7 +83,7 @@ public Optional> next() throws IOException { } protected List initScroll() throws IOException { - LOGGER.debug("[{}] Initializing scroll", "analytics"); + LOGGER.debug("[{}] Initializing scroll", context.jobId); SearchResponse searchResponse = executeSearchRequest(buildSearchRequest()); LOGGER.debug("[{}] Search response was obtained", context.jobId); return processSearchResponse(searchResponse); @@ -95,7 +96,7 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest private 
SearchRequestBuilder buildSearchRequest() { SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) .setScroll(SCROLL_TIMEOUT) - .addSort(DataFrameFields.ID, SortOrder.ASC) + .addSort(DataFrameAnalyticsFields.ID, SortOrder.ASC) .setIndices(context.indices) .setSize(context.scrollSize) .setQuery(context.query) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorContext.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java similarity index 95% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorContext.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java index 82de257ccaffd..f602a66221f7c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorContext.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorContext.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.dataframe; +package org.elasticsearch.xpack.ml.dataframe.extractor; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java similarity index 90% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 3080cd4b43a27..a9b225dee77ac 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.dataframe; +package org.elasticsearch.xpack.ml.dataframe.extractor; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -54,18 +54,20 @@ public class DataFrameDataExtractorFactory { } private final Client client; + private final String analyticsId; private final String index; private final ExtractedFields extractedFields; - private DataFrameDataExtractorFactory(Client client, String index, ExtractedFields extractedFields) { + private DataFrameDataExtractorFactory(Client client, String analyticsId, String index, ExtractedFields extractedFields) { this.client = Objects.requireNonNull(client); + this.analyticsId = Objects.requireNonNull(analyticsId); this.index = Objects.requireNonNull(index); this.extractedFields = Objects.requireNonNull(extractedFields); } public DataFrameDataExtractor newExtractor(boolean includeSource) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( - "ml-analytics-" + index, + analyticsId, extractedFields, Arrays.asList(index), QueryBuilders.matchAllQuery(), @@ -76,13 +78,14 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { return new DataFrameDataExtractor(client, context); } - public static void create(Client client, Map headers, String index, + public static void create(Client client, Map headers, String analyticsId, String index, ActionListener listener) { // Step 2. 
Contruct the factory and notify listener ActionListener fieldCapabilitiesHandler = ActionListener.wrap( fieldCapabilitiesResponse -> { - listener.onResponse(new DataFrameDataExtractorFactory(client, index, detectExtractedFields(fieldCapabilitiesResponse))); + listener.onResponse(new DataFrameDataExtractorFactory(client, analyticsId, index, + detectExtractedFields(index, fieldCapabilitiesResponse))); }, e -> { if (e instanceof IndexNotFoundException) { listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " @@ -105,7 +108,7 @@ public static void create(Client client, Map headers, String ind } // Visible for testing - static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapabilitiesResponse) { + static ExtractedFields detectExtractedFields(String index, FieldCapabilitiesResponse fieldCapabilitiesResponse) { Set fields = fieldCapabilitiesResponse.get().keySet(); fields.removeAll(IGNORE_FIELDS); removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse); @@ -115,7 +118,7 @@ static ExtractedFields detectExtractedFields(FieldCapabilitiesResponse fieldCapa ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse) .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE); if (extractedFields.getAllFields().isEmpty()) { - throw ExceptionsHelper.badRequestException("No compatible fields could be detected"); + throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index [{}]", index); } return extractedFields; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java new file mode 100644 index 0000000000000..ed340d155a8fd --- /dev/null +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.persistence; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; +import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; + +public class DataFrameAnalyticsConfigProvider { + + private static final Map TO_XCONTENT_PARAMS; + + static { + Map modifiable = new HashMap<>(); + modifiable.put(ToXContentParams.INCLUDE_TYPE, "true"); + TO_XCONTENT_PARAMS = 
Collections.unmodifiableMap(modifiable); + } + + private final Client client; + + public DataFrameAnalyticsConfigProvider(Client client) { + this.client = Objects.requireNonNull(client); + } + + public void put(DataFrameAnalyticsConfig config, ActionListener listener) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + config.toXContent(builder, new ToXContent.MapParams(TO_XCONTENT_PARAMS)); + IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.configIndexName()) + .id(DataFrameAnalyticsConfig.documentId(config.getId())) + .opType(DocWriteRequest.OpType.CREATE) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(builder); + + executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, ActionListener.wrap( + listener::onResponse, + e -> { + if (e instanceof VersionConflictEngineException) { + listener.onFailure(ExceptionsHelper.dataFrameAnalyticsAlreadyExists(config.getId())); + } else { + listener.onFailure(e); + } + } + )); + } catch (IOException e) { + listener.onFailure(new ElasticsearchParseException("Failed to serialise data frame analytics with id [" + config.getId() + + "]")); + } + } + + public void get(String id, ActionListener listener) { + GetDataFrameAnalyticsAction.Request request = new GetDataFrameAnalyticsAction.Request(); + request.setResourceId(id); + executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, request, ActionListener.wrap( + response -> { + List analytics = response.getResources().results(); + if (analytics.size() != 1) { + listener.onFailure(ExceptionsHelper.badRequestException("Expected a single match for data frame analytics [{}] " + + "but got [{}]", id, analytics.size())); + } else { + listener.onResponse(analytics.get(0)); + } + }, + listener::onFailure + )); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java index 0d925ed439d3f..c5e361c3e1215 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java @@ -20,7 +20,7 @@ public interface AnalyticsProcess extends NativeProcess { void writeEndOfDataMessage() throws IOException; /** - * @return stream of analytics results. + * @return stream of data frame analytics results. */ Iterator readAnalyticsResults(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index ce186b9232da3..36507e3da292b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -8,7 +8,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; import java.io.IOException; import java.util.Objects; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 987335d22f0a6..168662e5cf3c0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -13,17 +13,20 @@ import 
org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.env.Environment; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalysis; -import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractor; -import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; +import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; import java.io.IOException; import java.util.List; import java.util.Objects; import java.util.Optional; import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; public class AnalyticsProcessManager { @@ -42,19 +45,20 @@ public AnalyticsProcessManager(Client client, Environment environment, ThreadPoo this.processFactory = Objects.requireNonNull(analyticsProcessFactory); } - public void runJob(String jobId, DataFrameDataExtractorFactory dataExtractorFactory) { + public void runJob(DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory, + Consumer finishHandler) { threadPool.generic().execute(() -> { DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); - AnalyticsProcess process = createProcess(jobId, createProcessConfig(dataExtractor)); + AnalyticsProcess process = createProcess(config.getId(), createProcessConfig(config, dataExtractor)); ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(client, 
dataExtractorFactory.newExtractor(true)); executorService.execute(() -> resultProcessor.process(process)); - executorService.execute(() -> processData(jobId, dataExtractor, process, resultProcessor)); + executorService.execute(() -> processData(config.getId(), dataExtractor, process, resultProcessor, finishHandler)); }); } private void processData(String jobId, DataFrameDataExtractor dataExtractor, AnalyticsProcess process, - AnalyticsResultProcessor resultProcessor) { + AnalyticsResultProcessor resultProcessor, Consumer finishHandler) { try { writeHeaderRecord(dataExtractor, process); writeDataRows(dataExtractor, process); @@ -66,13 +70,18 @@ private void processData(String jobId, DataFrameDataExtractor dataExtractor, Ana LOGGER.info("[{}] Result processor has completed", jobId); } catch (IOException e) { LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", jobId), e); + // TODO Handle this failure by setting the task state to FAILED } finally { LOGGER.info("[{}] Closing process", jobId); try { process.close(); LOGGER.info("[{}] Closed process", jobId); + + // This results in marking the persistent task as complete + finishHandler.accept(null); } catch (IOException e) { LOGGER.error("[{}] Error closing data frame analyzer process", jobId); + finishHandler.accept(e); } } } @@ -119,15 +128,19 @@ private AnalyticsProcess createProcess(String jobId, AnalyticsProcessConfig anal ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, analyticsProcessConfig, executorService); if (process.isProcessAlive() == false) { - throw ExceptionsHelper.serverError("Failed to start analytics process"); + throw ExceptionsHelper.serverError("Failed to start data frame analytics process"); } return process; } - private AnalyticsProcessConfig createProcessConfig(DataFrameDataExtractor dataExtractor) { + private AnalyticsProcessConfig 
createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); - AnalyticsProcessConfig config = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, - new ByteSizeValue(1, ByteSizeUnit.GB), 1, new DataFrameAnalysis("outliers")); - return config; + List dataFrameAnalyses = DataFrameAnalysesUtils.readAnalyses(config.getAnalyses()); + // TODO We will not need this assertion after we add support for multiple analyses + assert dataFrameAnalyses.size() == 1; + + AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, + new ByteSizeValue(1, ByteSizeUnit.GB), 1, dataFrameAnalyses.get(0)); + return processConfig; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index f6be7a4e78e36..cd6864f049740 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -14,7 +14,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.client.Client; import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import java.util.ArrayList; import java.util.Iterator; @@ -76,12 +76,12 @@ public void process(AnalyticsProcess process) { currentDataFrameRows = null; } } catch (Exception e) { - LOGGER.warn("Error processing analytics result", e); + LOGGER.warn("Error processing data frame analytics result", e); } } } catch (Exception e) { - LOGGER.error("Error parsing analytics output", e); + LOGGER.error("Error parsing data 
frame analytics output", e); } finally { completionLatch.countDown(); process.consumeAndCloseOutputStream(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java index 8148a431f0a67..14743b93dc424 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java @@ -62,7 +62,7 @@ public AnalyticsProcess createAnalyticsProcess(String jobId, AnalyticsProcessCon try { IOUtils.close(analyticsProcess); } catch (IOException ioe) { - LOGGER.error("Can't close analytics", ioe); + LOGGER.error("Can't close data frame analytics process", ioe); } throw e; } @@ -76,7 +76,7 @@ private void createNativeProcess(String jobId, AnalyticsProcessConfig analyticsP analyticsBuilder.build(); processPipes.connectStreams(PROCESS_STARTUP_TIMEOUT); } catch (IOException e) { - String msg = "Failed to launch analytics for job " + jobId; + String msg = "Failed to launch data frame analytics process for job " + jobId; LOGGER.error(msg); throw ExceptionsHelper.serverError(msg, e); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AbstractDataToProcessWriter.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AbstractDataToProcessWriter.java index dc9d77cd68784..799954619b315 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AbstractDataToProcessWriter.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/writer/AbstractDataToProcessWriter.java @@ -122,10 +122,10 @@ public void writeHeader() throws IOException { /** * Tokenize the field that has been configured 
for categorization, and store the resulting list of tokens in CSV - * format in the appropriate field of the record to be sent to the analytics. + * format in the appropriate field of the record to be sent to the process. * @param categorizationAnalyzer The analyzer to use to convert the categorization field to a list of tokens * @param categorizationFieldValue The value of the categorization field to be tokenized - * @param record The record to be sent to the analytics + * @param record The record to be sent to the process */ protected void tokenizeForCategorization(CategorizationAnalyzer categorizationAnalyzer, String categorizationFieldValue, String[] record) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestDeleteDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestDeleteDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..31a9ba690a9b2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestDeleteDataFrameAnalyticsAction.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.dataframe; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +public class RestDeleteDataFrameAnalyticsAction extends BaseRestHandler { + + public RestDeleteDataFrameAnalyticsAction(Settings settings, RestController controller) { + super(settings); + controller.registerHandler(RestRequest.Method.DELETE, MachineLearning.BASE_PATH + "data_frame/analytics/{" + + DataFrameAnalyticsConfig.ID.getPreferredName() + "}", this); + } + + @Override + public String getName() { + return "xpack_ml_delete_data_frame_analytics_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String id = restRequest.param(DataFrameAnalyticsConfig.ID.getPreferredName()); + DeleteDataFrameAnalyticsAction.Request request = new DeleteDataFrameAnalyticsAction.Request(id); + return channel -> client.execute(DeleteDataFrameAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..6f2c89fd09c53 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.dataframe; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +public class RestGetDataFrameAnalyticsAction extends BaseRestHandler { + + public RestGetDataFrameAnalyticsAction(Settings settings, RestController controller) { + super(settings); + controller.registerHandler(RestRequest.Method.GET, MachineLearning.BASE_PATH + "data_frame/analytics", this); + controller.registerHandler(RestRequest.Method.GET, MachineLearning.BASE_PATH + "data_frame/analytics/{" + + DataFrameAnalyticsConfig.ID.getPreferredName() + "}", this); + } + + @Override + public String getName() { + return "xpack_ml_get_data_frame_analytics_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + GetDataFrameAnalyticsAction.Request request = new GetDataFrameAnalyticsAction.Request(); + String id = restRequest.param(DataFrameAnalyticsConfig.ID.getPreferredName()); + if (Strings.isNullOrEmpty(id) == false) { + request.setResourceId(id); + } + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + 
request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); + } + + return channel -> client.execute(GetDataFrameAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java new file mode 100644 index 0000000000000..a2d2b1ca48e27 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.dataframe; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +public class RestGetDataFrameAnalyticsStatsAction extends BaseRestHandler { + + public RestGetDataFrameAnalyticsStatsAction(Settings settings, RestController controller) { + super(settings); + controller.registerHandler(RestRequest.Method.GET, MachineLearning.BASE_PATH + "data_frame/analytics/_stats", this); + controller.registerHandler(RestRequest.Method.GET, MachineLearning.BASE_PATH + "data_frame/analytics/{" + + DataFrameAnalyticsConfig.ID.getPreferredName() + "}/_stats", this); + } + + @Override + public String getName() { + return "xpack_ml_get_data_frame_analytics_stats_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String id = restRequest.param(DataFrameAnalyticsConfig.ID.getPreferredName()); + GetDataFrameAnalyticsStatsAction.Request request = new GetDataFrameAnalyticsStatsAction.Request(); + if (Strings.isNullOrEmpty(id) == false) { + request.setId(id); + } + if (restRequest.hasParam(PageParams.FROM.getPreferredName()) || restRequest.hasParam(PageParams.SIZE.getPreferredName())) { + request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), + restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), 
PageParams.DEFAULT_SIZE))); + } + + return channel -> client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestPutDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestPutDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..e2422c6cdeba9 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestPutDataFrameAnalyticsAction.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.rest.dataframe; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +public class RestPutDataFrameAnalyticsAction extends BaseRestHandler { + + public RestPutDataFrameAnalyticsAction(Settings settings, RestController controller) { + super(settings); + controller.registerHandler(RestRequest.Method.PUT, MachineLearning.BASE_PATH + "data_frame/analytics/{" + + DataFrameAnalyticsConfig.ID.getPreferredName() + "}", this); + } + + @Override + public String getName() { + return "xpack_ml_put_data_frame_analytics_action"; + } + + @Override + protected 
RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String id = restRequest.param(DataFrameAnalyticsConfig.ID.getPreferredName()); + XContentParser parser = restRequest.contentParser(); + PutDataFrameAnalyticsAction.Request putRequest = PutDataFrameAnalyticsAction.Request.parseRequest(id, parser); + putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout())); + + return channel -> client.execute(PutDataFrameAnalyticsAction.INSTANCE, putRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/analytics/RestRunAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java similarity index 50% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/analytics/RestRunAnalyticsAction.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java index feadce6ff302d..035822429cc29 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/analytics/RestRunAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.rest.analytics; +package org.elasticsearch.xpack.ml.rest.dataframe; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.settings.Settings; @@ -11,28 +11,30 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.core.ml.action.RunAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.ml.MachineLearning; import java.io.IOException; -public class RestRunAnalyticsAction extends BaseRestHandler { +public class RestStartDataFrameAnalyticsAction extends BaseRestHandler { - public RestRunAnalyticsAction(Settings settings, RestController controller) { + public RestStartDataFrameAnalyticsAction(Settings settings, RestController controller) { super(settings); - controller.registerHandler(RestRequest.Method.POST, MachineLearning.BASE_PATH + "analytics/{index}/_run", this); + controller.registerHandler(RestRequest.Method.POST, MachineLearning.BASE_PATH + "data_frame/analytics/{" + + DataFrameAnalyticsConfig.ID.getPreferredName() + "}/_start", this); } @Override public String getName() { - return "xpack_ml_run_analytics_action"; + return "xpack_ml_start_data_frame_analytics_action"; } @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - RunAnalyticsAction.Request request = new RunAnalyticsAction.Request(restRequest.param("index")); - return channel -> { - client.execute(RunAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel)); - }; + String id = restRequest.param(DataFrameAnalyticsConfig.ID.getPreferredName()); + StartDataFrameAnalyticsAction.Request request = new StartDataFrameAnalyticsAction.Request(id); + + return channel -> 
client.execute(StartDataFrameAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java new file mode 100644 index 0000000000000..b95fd32f7288d --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class DataFrameAnalysesUtilsTests extends ESTestCase { + + public void testReadAnalysis_GivenEmptyAnalysisList() { + assertThat(DataFrameAnalysesUtils.readAnalyses(Collections.emptyList()).isEmpty(), is(true)); + } + + public void testReadAnalysis_GivenUnknownAnalysis() { + String analysisJson = "{\"unknown_analysis\": {}}"; + DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); + + ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, + () -> DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig))); + + 
assertThat(e.getMessage(), equalTo("Unknown analysis type [unknown_analysis]")); + } + + public void testReadAnalysis_GivenAnalysisIsNotAnObject() { + String analysisJson = "{\"outlier_detection\": 42}"; + DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); + + ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, + () -> DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig))); + + assertThat(e.getMessage(), equalTo("[outlier_detection] expected to be a map but was of type [java.lang.Integer]")); + } + + public void testReadAnalysis_GivenUnusedParameters() { + String analysisJson = "{\"outlier_detection\": {\"number_neighbours\":42, \"foo\": 1}}"; + DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); + + ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, + () -> DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig))); + + assertThat(e.getMessage(), equalTo("Data frame analysis [outlier_detection] does not support one or more provided " + + "parameters [foo]")); + } + + public void testReadAnalysis_GivenValidOutlierDetection() { + String analysisJson = "{\"outlier_detection\": {\"number_neighbours\":42}}"; + DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); + + List analyses = DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig)); + + assertThat(analyses.size(), equalTo(1)); + assertThat(analyses.get(0), is(instanceOf(OutlierDetection.class))); + OutlierDetection outlierDetection = (OutlierDetection) analyses.get(0); + assertThat(outlierDetection.getParams().size(), equalTo(1)); + assertThat(outlierDetection.getParams().get("number_neighbours"), equalTo(42)); + } + + private static DataFrameAnalysisConfig createAnalysisConfig(String json) { + Map asMap = XContentHelper.convertToMap(new BytesArray(json), true, XContentType.JSON).v2(); + return new 
DataFrameAnalysisConfig(asMap); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java new file mode 100644 index 0000000000000..59a838acc8cd8 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.analyses; + +import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.test.ESTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class OutlierDetectionTests extends ESTestCase { + + public void testCreate_GivenNumberNeighboursNotInt() { + Map config = new HashMap<>(); + config.put(OutlierDetection.NUMBER_NEIGHBOURS, "42"); + + DataFrameAnalysis.Factory factory = new OutlierDetection.Factory(); + + ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, () -> factory.create(config)); + assertThat(e.getMessage(), equalTo("Property [number_neighbours] of analysis [outlier_detection] should be of " + + "type [Integer] but was [String]")); + } + + public void testCreate_GivenMethodNotString() { + Map config = new HashMap<>(); + config.put(OutlierDetection.METHOD, 42); + + DataFrameAnalysis.Factory factory = new OutlierDetection.Factory(); + + ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, () -> factory.create(config)); + assertThat(e.getMessage(), equalTo("Property [method] of analysis [outlier_detection] should be 
of " + + "type [String] but was [Integer]")); + } + + public void testCreate_GivenEmptyParams() { + DataFrameAnalysis.Factory factory = new OutlierDetection.Factory(); + OutlierDetection outlierDetection = (OutlierDetection) factory.create(Collections.emptyMap()); + assertThat(outlierDetection.getParams().isEmpty(), is(true)); + } + + public void testCreate_GivenFullParams() { + Map config = new HashMap<>(); + config.put(OutlierDetection.NUMBER_NEIGHBOURS, 42); + config.put(OutlierDetection.METHOD, "ldof"); + + DataFrameAnalysis.Factory factory = new OutlierDetection.Factory(); + OutlierDetection outlierDetection = (OutlierDetection) factory.create(config); + + assertThat(outlierDetection.getParams().size(), equalTo(2)); + assertThat(outlierDetection.getParams().get(OutlierDetection.NUMBER_NEIGHBOURS), equalTo(42)); + assertThat(outlierDetection.getParams().get(OutlierDetection.METHOD), equalTo(OutlierDetection.Method.LDOF)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java similarity index 91% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactoryTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java index 82c492f579e10..9aac76da91e95 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/DataFrameDataExtractorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.ml.dataframe; +package org.elasticsearch.xpack.ml.dataframe.extractor; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.fieldcaps.FieldCapabilities; @@ -26,11 +26,13 @@ public class DataFrameDataExtractorFactoryTests extends ESTestCase { + private static final String INDEX = "source_index"; + public void testDetectExtractedFields_GivenFloatField() { FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float").build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); @@ -42,7 +44,7 @@ public void testDetectExtractedFields_GivenNumericFieldWithMultipleTypes() { .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") .build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); @@ -54,8 +56,8 @@ public void testDetectExtractedFields_GivenNonNumericField() { .addAggregatableField("some_keyword", "keyword").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities)); - assertThat(e.getMessage(), equalTo("No compatible fields could be detected")); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } public 
void testDetectExtractedFields_GivenFieldWithNumericAndNonNumericTypes() { @@ -63,8 +65,8 @@ public void testDetectExtractedFields_GivenFieldWithNumericAndNonNumericTypes() .addAggregatableField("indecisive_field", "float", "keyword").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities)); - assertThat(e.getMessage(), equalTo("No compatible fields could be detected")); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } public void testDetectExtractedFields_GivenMultipleFields() { @@ -74,7 +76,7 @@ public void testDetectExtractedFields_GivenMultipleFields() { .addAggregatableField("some_keyword", "keyword") .build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(2)); @@ -87,8 +89,8 @@ public void testDetectExtractedFields_GivenIgnoredField() { .addAggregatableField("_id", "float").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities)); - assertThat(e.getMessage(), equalTo("No compatible fields could be detected")); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() { @@ -106,7 +108,7 @@ public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() { } FieldCapabilitiesResponse fieldCapabilities = 
mockFieldCapsResponseBuilder.build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 1d38049ff23ce..15d2e012080e1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -16,7 +16,7 @@ import org.elasticsearch.common.text.Text; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.ml.dataframe.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.junit.Before; import org.mockito.ArgumentCaptor; diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json new file mode 100644 index 0000000000000..a09259fabb8be --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json @@ -0,0 +1,17 @@ +{ + "ml.delete_data_frame_analytics": { + "methods": [ "DELETE" ], + "url": { + "path": "/_ml/data_frame/analytics/{id}", + "paths": [ "/_ml/data_frame/analytics/{id}" ], + "parts": { + "id": { + "type" : "string", + "required" : true, + "description" : "The ID of the data frame analytics to delete" + } + } + }, + "body": null + } +} diff --git 
a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json new file mode 100644 index 0000000000000..9c65661d254a4 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json @@ -0,0 +1,29 @@ +{ + "ml.get_data_frame_analytics": { + "methods": [ "GET"], + "url": { + "path": "/_ml/data_frame/analytics/{id}", + "paths": [ + "/_ml/data_frame/analytics/{id}", + "/_ml/data_frame/analytics" + ], + "parts": { + "id": { + "type": "string", + "description": "The ID of the data frame analytics to fetch" + } + }, + "params": { + "from": { + "type": "int", + "description": "skips a number of analytics" + }, + "size": { + "type": "int", + "description": "specifies a max number of analytics to get" + } + } + }, + "body": null + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json new file mode 100644 index 0000000000000..d74f5880c72de --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json @@ -0,0 +1,29 @@ +{ + "ml.get_data_frame_analytics_stats": { + "methods": [ "GET"], + "url": { + "path": "/_ml/data_frame/analytics/{id}/_stats", + "paths": [ + "/_ml/data_frame/analytics/_stats", + "/_ml/data_frame/analytics/{id}/_stats" + ], + "parts": { + "id": { + "type": "string", + "description": "The ID of the data frame analytics stats to fetch" + } + }, + "params": { + "from": { + "type": "int", + "description": "skips a number of analytics" + }, + "size": { + "type": "int", + "description": "specifies a max number of analytics to get" + } + } + }, + "body": null + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json 
b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json new file mode 100644 index 0000000000000..1f3183920aca5 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json @@ -0,0 +1,20 @@ +{ + "ml.put_data_frame_analytics": { + "methods": [ "PUT" ], + "url": { + "path": "/_ml/data_frame/analytics/{id}", + "paths": [ "/_ml/data_frame/analytics/{id}" ], + "parts": { + "id": { + "type": "string", + "required": true, + "description": "The ID of the data frame analytics to create" + } + } + }, + "body": { + "description" : "The data frame analytics configuration", + "required" : true + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json new file mode 100644 index 0000000000000..b4e61b3fab125 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json @@ -0,0 +1,16 @@ +{ + "ml.start_data_frame_analytics": { + "methods": [ "POST" ], + "url": { + "path": "/_ml/data_frame/analytics/{id}/_start", + "paths": [ "/_ml/data_frame/analytics/{id}/_start" ], + "parts": { + "id": { + "type": "string", + "required": true, + "description": "The ID of the data frame analytics to start" + } + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml new file mode 100644 index 0000000000000..85e07d63cba4a --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -0,0 +1,370 @@ +--- +"Test get-all and get-all-stats given no analytics exist": + + - do: + ml.get_data_frame_analytics: + id: "_all" + - match: { count: 0 } + - match: { data_frame_analytics: [] } + + - do: + ml.get_data_frame_analytics: + id: "_all" + - match: { count: 0 } + - 
match: { data_frame_analytics: [] } + + - do: + ml.get_data_frame_analytics: + id: "*" + - match: { count: 0 } + - match: { data_frame_analytics: [] } + + - do: + ml.get_data_frame_analytics: + id: "*" + - match: { count: 0 } + - match: { data_frame_analytics: [] } + +--- +"Test put valid config with default outlier detection": + + - do: + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": "source_index", + "dest": "dest_index", + "analyses": [{"outlier_detection":{}}] + } + - match: { id: "simple-outlier-detection" } + - match: { source: "source_index" } + - match: { dest: "dest_index" } + - match: { analyses: [{"outlier_detection":{}}] } + +--- +"Test put config with inconsistent body/param ids": + + - do: + catch: /Inconsistent id; 'body_id' specified in the body differs from 'url_id' specified as a URL argument/ + ml.put_data_frame_analytics: + id: "url_id" + body: > + { + "id": "body_id", + "source": "source_index", + "dest": "dest_index", + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config with invalid id": + + - do: + catch: /Invalid id*/ + ml.put_data_frame_analytics: + id: "this id contains spaces" + body: > + { + "source": "source_index", + "dest": "dest_index", + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config with unknown top level field": + + - do: + catch: /unknown field \[unknown_field\], parser not found/ + ml.put_data_frame_analytics: + id: "unknown_field" + body: > + { + "source": "source_index", + "dest": "dest_index", + "analyses": [{"outlier_detection":{}}], + "unknown_field": 42 + } + +--- +"Test put config with unknown field in outlier detection analysis": + + - do: + catch: /Data frame analysis \[outlier_detection\] does not support one or more provided parameters \[unknown_field\]/ + ml.put_data_frame_analytics: + id: "unknown_field" + body: > + { + "source": "source_index", + "dest": "dest_index", + "analyses": [{"outlier_detection":{"unknown_field": 42}}] + } + 
+--- +"Test put config given missing source": + + - do: + catch: /\[source\] must not be null/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "dest": "dest_index", + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config given missing dest": + + - do: + catch: /\[dest\] must not be null/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": "source_index", + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config given missing analyses": + + - do: + catch: /\[analyses\] must not be null/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": "source_index", + "dest": "dest_index" + } + +--- +"Test put config given empty analyses": + + - do: + catch: /One or more analyses are required/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": "source_index", + "dest": "dest_index", + "analyses": [] + } + +--- +"Test put config given two analyses": + + - do: + catch: /Does not yet support multiple analyses/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": "source_index", + "dest": "dest_index", + "analyses": [{"outlier_detection":{}}, {"outlier_detection":{}}] + } + +--- +"Test get given multiple analytics": + + - do: + ml.put_data_frame_analytics: + id: "foo-1" + body: > + { + "source": "foo-1_source", + "dest": "foo-1_dest", + "analyses": [{"outlier_detection":{}}] + } + + - do: + ml.put_data_frame_analytics: + id: "foo-2" + body: > + { + "source": "foo-2_source", + "dest": "foo-2_dest", + "analyses": [{"outlier_detection":{}}] + } + - match: { id: "foo-2" } + + - do: + ml.put_data_frame_analytics: + id: "bar" + body: > + { + "source": "bar_source", + "dest": "bar_dest", + "analyses": [{"outlier_detection":{}}] + } + - match: { id: "bar" } + + - do: + ml.get_data_frame_analytics: + id: "*" + - match: { count: 3 } + - match: { data_frame_analytics.0.id: "bar" } + - 
match: { data_frame_analytics.1.id: "foo-1" } + - match: { data_frame_analytics.2.id: "foo-2" } + + - do: + ml.get_data_frame_analytics: + id: "foo-*" + - match: { count: 2 } + - match: { data_frame_analytics.0.id: "foo-1" } + - match: { data_frame_analytics.1.id: "foo-2" } + + - do: + ml.get_data_frame_analytics: + id: "bar" + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "bar" } + + - do: + ml.get_data_frame_analytics: + from: 1 + - match: { count: 2 } + - match: { data_frame_analytics.0.id: "foo-1" } + - match: { data_frame_analytics.1.id: "foo-2" } + + - do: + ml.get_data_frame_analytics: + size: 2 + - match: { count: 2 } + - match: { data_frame_analytics.0.id: "bar" } + - match: { data_frame_analytics.1.id: "foo-1" } + + - do: + ml.get_data_frame_analytics: + from: 1 + size: 1 + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "foo-1" } + +--- +"Test get given missing analytics": + + - do: + catch: missing + ml.get_data_frame_analytics: + id: "missing-analytics" + +--- +"Test get stats given multiple analytics": + + - do: + ml.put_data_frame_analytics: + id: "foo-1" + body: > + { + "source": "foo-1_source", + "dest": "foo-1_dest", + "analyses": [{"outlier_detection":{}}] + } + + - do: + ml.put_data_frame_analytics: + id: "foo-2" + body: > + { + "source": "foo-2_source", + "dest": "foo-2_dest", + "analyses": [{"outlier_detection":{}}] + } + - match: { id: "foo-2" } + + - do: + ml.put_data_frame_analytics: + id: "bar" + body: > + { + "source": "bar_source", + "dest": "bar_dest", + "analyses": [{"outlier_detection":{}}] + } + - match: { id: "bar" } + + - do: + ml.get_data_frame_analytics_stats: + id: "*" + - match: { count: 3 } + - match: { data_frame_analytics.0.id: "bar" } + - match: { data_frame_analytics.0.state: "stopped" } + - match: { data_frame_analytics.1.id: "foo-1" } + - match: { data_frame_analytics.1.state: "stopped" } + - match: { data_frame_analytics.2.id: "foo-2" } + - match: { data_frame_analytics.2.state: "stopped" } 
+ + - do: + ml.get_data_frame_analytics_stats: + id: "foo-*" + - match: { count: 2 } + - match: { data_frame_analytics.0.id: "foo-1" } + - match: { data_frame_analytics.0.state: "stopped" } + - match: { data_frame_analytics.1.id: "foo-2" } + - match: { data_frame_analytics.1.state: "stopped" } + + - do: + ml.get_data_frame_analytics_stats: + id: "bar" + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "bar" } + - match: { data_frame_analytics.0.state: "stopped" } + + - do: + ml.get_data_frame_analytics_stats: + from: 2 + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "foo-2" } + - match: { data_frame_analytics.0.state: "stopped" } + + - do: + ml.get_data_frame_analytics_stats: + size: 2 + - match: { count: 2 } + - match: { data_frame_analytics.0.id: "bar" } + - match: { data_frame_analytics.0.state: "stopped" } + - match: { data_frame_analytics.1.id: "foo-1" } + - match: { data_frame_analytics.1.state: "stopped" } + + - do: + ml.get_data_frame_analytics_stats: + from: 1 + size: 1 + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "foo-1" } + - match: { data_frame_analytics.0.state: "stopped" } + +--- +"Test delete given stopped config": + + - do: + ml.put_data_frame_analytics: + id: "foo" + body: > + { + "source": "source", + "dest": "dest", + "analyses": [{"outlier_detection":{}}] + } + + - do: + ml.delete_data_frame_analytics: + id: "foo" + - match: { acknowledged: true } + + - do: + catch: missing + ml.get_data_frame_analytics: + id: "foo" + +--- +"Test delete given missing config": + + - do: + catch: missing + ml.delete_data_frame_analytics: + id: "missing_config" diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml new file mode 100644 index 0000000000000..36e50b0229737 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml @@ -0,0 +1,46 
@@ +--- +"Test start given missing config": + - do: + catch: missing + ml.start_data_frame_analytics: + id: "missing_config" + +--- +"Test start given missing source index": + + - do: + ml.put_data_frame_analytics: + id: "missing_index" + body: > + { + "source": "missing", + "dest": "missing-dest", + "analyses": [{"outlier_detection":{}}] + } + + - do: + catch: /cannot retrieve data because index \[missing\] does not exist/ + ml.start_data_frame_analytics: + id: "missing_index" + +--- +"Test start given source index has no compatible fields": + + - do: + indices.create: + index: empty-index + + - do: + ml.put_data_frame_analytics: + id: "foo" + body: > + { + "source": "empty-index", + "dest": "empty-index-dest", + "analyses": [{"outlier_detection":{}}] + } + + - do: + catch: /No compatible fields could be detected in index \[empty-index\]/ + ml.start_data_frame_analytics: + id: "foo" From ceab407b317034ada82ca05c82fc1ea9bd39f019 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 5 Feb 2019 18:17:38 +0200 Subject: [PATCH 16/67] =?UTF-8?q?[FEATURE][ML]=20Allow=20parsing=20differe?= =?UTF-8?q?nt=20types=20of=20results=20from=20analytics=E2=80=A6=20(#38430?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process/AnalyticsProcessManager.java | 8 +- .../ml/dataframe/process/AnalyticsResult.java | 34 ++--- .../process/AnalyticsResultProcessor.java | 86 ++---------- .../process/DataFrameRowsJoiner.java | 126 +++++++++++++++++ .../dataframe/process/results/RowResults.java | 73 ++++++++++ .../AnalyticsResultProcessorTests.java | 107 +++----------- .../process/AnalyticsResultTests.java | 16 +-- .../process/DataFrameRowsJoinerTests.java | 131 ++++++++++++++++++ .../process/results/RowResultsTests.java | 42 ++++++ 9 files changed, 423 insertions(+), 200 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java create mode 100644 
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 168662e5cf3c0..5868cefe3a30e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -16,10 +16,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import java.io.IOException; import java.util.List; @@ -51,7 +51,9 @@ public void runJob(DataFrameAnalyticsConfig config, DataFrameDataExtractorFactor DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); AnalyticsProcess process = createProcess(config.getId(), createProcessConfig(config, dataExtractor)); ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); - AnalyticsResultProcessor 
resultProcessor = new AnalyticsResultProcessor(client, dataExtractorFactory.newExtractor(true)); + DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, + dataExtractorFactory.newExtractor(true)); + AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(dataFrameRowsJoiner); executorService.execute(() -> resultProcessor.process(process)); executorService.execute(() -> processData(config.getId(), dataExtractor, process, resultProcessor, finishHandler)); }); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java index 3e9c1b8b9cd57..4d15bf89b29e9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java @@ -9,46 +9,38 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import java.io.IOException; -import java.util.Map; import java.util.Objects; public class AnalyticsResult implements ToXContentObject { public static final ParseField TYPE = new ParseField("analytics_result"); - public static final ParseField CHECKSUM = new ParseField("checksum"); - public static final ParseField RESULTS = new ParseField("results"); static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), - a -> new AnalyticsResult((Integer) a[0], (Map) a[1])); + a -> new AnalyticsResult((RowResults) a[0])); static { - PARSER.declareInt(ConstructingObjectParser.constructorArg(), CHECKSUM); - PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS); + 
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); } - private final int checksum; - private final Map results; + private final RowResults rowResults; - public AnalyticsResult(int checksum, Map results) { - this.checksum = Objects.requireNonNull(checksum); - this.results = Objects.requireNonNull(results); + public AnalyticsResult(RowResults rowResults) { + this.rowResults = rowResults; } - public int getChecksum() { - return checksum; - } - - public Map getResults() { - return results; + public RowResults getRowResults() { + return rowResults; } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(CHECKSUM.getPreferredName(), checksum); - builder.field(RESULTS.getPreferredName(), results); + if (rowResults != null) { + builder.field(RowResults.TYPE.getPreferredName(), rowResults); + } builder.endObject(); return builder; } @@ -63,11 +55,11 @@ public boolean equals(Object other) { } AnalyticsResult that = (AnalyticsResult) other; - return checksum == that.checksum && Objects.equals(results, that.results); + return Objects.equals(rowResults, that.rowResults); } @Override public int hashCode() { - return Objects.hash(checksum, results); + return Objects.hash(rowResults); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index cd6864f049740..1b70d68598df6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -7,22 +7,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.DocWriteRequest; -import 
org.elasticsearch.action.bulk.BulkAction; -import org.elasticsearch.action.bulk.BulkRequest; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.client.Client; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; -import java.util.ArrayList; import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -30,15 +18,11 @@ public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); - private final Client client; - private final DataFrameDataExtractor dataExtractor; - private List currentDataFrameRows; - private List currentResults; + private final DataFrameRowsJoiner dataFrameRowsJoiner; private final CountDownLatch completionLatch = new CountDownLatch(1); - public AnalyticsResultProcessor(Client client, DataFrameDataExtractor dataExtractor) { - this.client = Objects.requireNonNull(client); - this.dataExtractor = Objects.requireNonNull(dataExtractor); + public AnalyticsResultProcessor(DataFrameRowsJoiner dataFrameRowsJoiner) { + this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); } public void awaitForCompletion() { @@ -57,28 +41,8 @@ public void process(AnalyticsProcess process) { try { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { - try { - AnalyticsResult result = iterator.next(); - if (dataExtractor.hasNext() == false) { - return; - } - if (currentDataFrameRows == null) { - Optional> nextBatch = dataExtractor.next(); - if (nextBatch.isPresent() == false) { - return; - } - currentDataFrameRows = nextBatch.get(); - currentResults = new 
ArrayList<>(currentDataFrameRows.size()); - } - currentResults.add(result); - if (currentResults.size() == currentDataFrameRows.size()) { - joinCurrentResults(); - currentDataFrameRows = null; - } - } catch (Exception e) { - LOGGER.warn("Error processing data frame analytics result", e); - } - + AnalyticsResult result = iterator.next(); + processResult(result); } } catch (Exception e) { LOGGER.error("Error parsing data frame analytics output", e); @@ -88,40 +52,10 @@ public void process(AnalyticsProcess process) { } } - private void joinCurrentResults() { - BulkRequest bulkRequest = new BulkRequest(); - for (int i = 0; i < currentDataFrameRows.size(); i++) { - DataFrameDataExtractor.Row row = currentDataFrameRows.get(i); - if (row.shouldSkip()) { - continue; - } - AnalyticsResult result = currentResults.get(i); - checkChecksumsMatch(row, result); - - SearchHit hit = row.getHit(); - Map source = new LinkedHashMap(hit.getSourceAsMap()); - source.putAll(result.getResults()); - IndexRequest indexRequest = new IndexRequest(hit.getIndex(), hit.getType(), hit.getId()); - indexRequest.source(source); - indexRequest.opType(DocWriteRequest.OpType.INDEX); - bulkRequest.add(indexRequest); - } - if (bulkRequest.numberOfActions() > 0) { - BulkResponse bulkResponse = client.execute(BulkAction.INSTANCE, bulkRequest).actionGet(); - if (bulkResponse.hasFailures()) { - LOGGER.error("Failures while writing data frame"); - // TODO Better error handling - } - } - } - - private void checkChecksumsMatch(DataFrameDataExtractor.Row row, AnalyticsResult result) { - if (row.getChecksum() != result.getChecksum()) { - String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; "; - msg += "expected [" + row.getChecksum() + "] but result had [" + result.getChecksum() + "]; "; - msg += "this implies the data frame index [" + row.getHit().getIndex() + "] was modified while the analysis was running. 
"; - msg += "We rely on this index being immutable during a running analysis and so the results will be unreliable."; - throw new IllegalStateException(msg); + private void processResult(AnalyticsResult result) { + RowResults rowResults = result.getRowResults(); + if (rowResults != null) { + dataFrameRowsJoiner.processRowResults(rowResults); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java new file mode 100644 index 0000000000000..76ebe166a39ad --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +class DataFrameRowsJoiner { + + private static final Logger LOGGER = LogManager.getLogger(DataFrameRowsJoiner.class); + + private final String analyticsId; + private final Client client; + private final DataFrameDataExtractor dataExtractor; + private List currentDataFrameRows; + private List currentResults; + private boolean failed; + + DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) { + this.analyticsId = Objects.requireNonNull(analyticsId); + this.client = Objects.requireNonNull(client); + this.dataExtractor = Objects.requireNonNull(dataExtractor); + } + + void processRowResults(RowResults rowResults) { + if (failed) { + // If we are in failed state we drop the results but we let the processor + // parse the output + return; + } + + try { + addResultAndJoinIfEndOfBatch(rowResults); + } catch (Exception e) { + LOGGER.error(new ParameterizedMessage("[{}] Failed to join results", analyticsId), e); + failed = true; + } + } + + private void addResultAndJoinIfEndOfBatch(RowResults rowResults) { + if (currentDataFrameRows == null) { + 
Optional> nextBatch = getNextBatch(); + if (nextBatch.isPresent() == false) { + return; + } + currentDataFrameRows = nextBatch.get(); + currentResults = new ArrayList<>(currentDataFrameRows.size()); + } + currentResults.add(rowResults); + if (currentResults.size() == currentDataFrameRows.size()) { + joinCurrentResults(); + currentDataFrameRows = null; + } + } + + private Optional> getNextBatch() { + try { + return dataExtractor.next(); + } catch (IOException e) { + // TODO Implement recovery strategy or better error reporting + LOGGER.error("Error reading next batch of data frame rows", e); + return Optional.empty(); + } + } + + private void joinCurrentResults() { + BulkRequest bulkRequest = new BulkRequest(); + for (int i = 0; i < currentDataFrameRows.size(); i++) { + DataFrameDataExtractor.Row row = currentDataFrameRows.get(i); + if (row.shouldSkip()) { + continue; + } + RowResults result = currentResults.get(i); + checkChecksumsMatch(row, result); + + SearchHit hit = row.getHit(); + Map source = new LinkedHashMap(hit.getSourceAsMap()); + source.putAll(result.getResults()); + new IndexRequest(hit.getIndex()); + IndexRequest indexRequest = new IndexRequest(hit.getIndex()); + indexRequest.id(hit.getId()); + indexRequest.source(source); + indexRequest.opType(DocWriteRequest.OpType.INDEX); + bulkRequest.add(indexRequest); + } + if (bulkRequest.numberOfActions() > 0) { + BulkResponse bulkResponse = client.execute(BulkAction.INSTANCE, bulkRequest).actionGet(); + if (bulkResponse.hasFailures()) { + LOGGER.error("Failures while writing data frame"); + // TODO Better error handling + } + } + } + + private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults result) { + if (row.getChecksum() != result.getChecksum()) { + String msg = "Detected checksum mismatch for document with id [" + row.getHit().getId() + "]; "; + msg += "expected [" + row.getChecksum() + "] but result had [" + result.getChecksum() + "]; "; + msg += "this implies the data frame index [" + 
row.getHit().getIndex() + "] was modified while the analysis was running. "; + msg += "We rely on this index being immutable during a running analysis and so the results will be unreliable."; + throw new RuntimeException(msg); + // TODO Communicate this error to the user as effectively the analytics have failed (e.g. FAILED state, audit error, etc.) + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java new file mode 100644 index 0000000000000..ba4aebededa2e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResults.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.process.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class RowResults implements ToXContentObject { + + public static final ParseField TYPE = new ParseField("row_results"); + public static final ParseField CHECKSUM = new ParseField("checksum"); + public static final ParseField RESULTS = new ParseField("results"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), + a -> new RowResults((Integer) a[0], (Map) a[1])); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), CHECKSUM); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, context) -> p.map(), RESULTS); + } + + private final int checksum; + private final Map results; + + public RowResults(int checksum, Map results) { + this.checksum = Objects.requireNonNull(checksum); + this.results = Objects.requireNonNull(results); + } + + public int getChecksum() { + return checksum; + } + + public Map getResults() { + return results; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CHECKSUM.getPreferredName(), checksum); + builder.field(RESULTS.getPreferredName(), results); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + + RowResults that = (RowResults) other; + return checksum == that.checksum && Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(checksum, 
results); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 15d2e012080e1..cded955767344 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -5,47 +5,29 @@ */ package org.elasticsearch.xpack.ml.dataframe.process; -import org.elasticsearch.action.ActionFuture; -import org.elasticsearch.action.bulk.BulkAction; -import org.elasticsearch.action.bulk.BulkItemResponse; -import org.elasticsearch.action.bulk.BulkRequest; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.client.Client; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.text.Text; -import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.junit.Before; -import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.Mockito; -import java.io.IOException; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Optional; -import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; public class AnalyticsResultProcessorTests extends ESTestCase { - private Client client; private AnalyticsProcess process; - private DataFrameDataExtractor dataExtractor; - 
private ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + private DataFrameRowsJoiner dataFrameRowsJoiner; @Before public void setUpMocks() { - client = mock(Client.class); process = mock(AnalyticsProcess.class); - dataExtractor = mock(DataFrameDataExtractor.class); + dataFrameRowsJoiner = mock(DataFrameRowsJoiner.class); } public void testProcess_GivenNoResults() { @@ -55,93 +37,38 @@ public void testProcess_GivenNoResults() { resultProcessor.process(process); resultProcessor.awaitForCompletion(); - verifyNoMoreInteractions(client); + verifyNoMoreInteractions(dataFrameRowsJoiner); } - public void testProcess_GivenSingleRowAndResult() throws IOException { - givenClientHasNoFailures(); - - String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; - String[] dataValues = {"42.0"}; - DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); - givenSingleDataFrameBatch(Arrays.asList(row)); - - Map resultFields = new HashMap<>(); - resultFields.put("a", "1"); - resultFields.put("b", "2"); - AnalyticsResult result = new AnalyticsResult(1, resultFields); - givenProcessResults(Arrays.asList(result)); - + public void testProcess_GivenEmptyResults() { + givenProcessResults(Arrays.asList(new AnalyticsResult(null), new AnalyticsResult(null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); resultProcessor.awaitForCompletion(); - List capturedBulkRequests = bulkRequestCaptor.getAllValues(); - assertThat(capturedBulkRequests.size(), equalTo(1)); - BulkRequest capturedBulkRequest = capturedBulkRequests.get(0); - assertThat(capturedBulkRequest.numberOfActions(), equalTo(1)); - IndexRequest indexRequest = (IndexRequest) capturedBulkRequest.requests().get(0); - Map indexedDocSource = indexRequest.sourceAsMap(); - assertThat(indexedDocSource.size(), equalTo(4)); - assertThat(indexedDocSource.get("f_1"), equalTo("foo")); - assertThat(indexedDocSource.get("f_2"), equalTo(42.0)); - 
assertThat(indexedDocSource.get("a"), equalTo("1")); - assertThat(indexedDocSource.get("b"), equalTo("2")); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner); } - public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IOException { - givenClientHasNoFailures(); - - String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; - String[] dataValues = {"42.0"}; - DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); - givenSingleDataFrameBatch(Arrays.asList(row)); - - Map resultFields = new HashMap<>(); - resultFields.put("a", "1"); - resultFields.put("b", "2"); - AnalyticsResult result = new AnalyticsResult(2, resultFields); - givenProcessResults(Arrays.asList(result)); - + public void testProcess_GivenRowResults() { + RowResults rowResults1 = mock(RowResults.class); + RowResults rowResults2 = mock(RowResults.class); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1), new AnalyticsResult(rowResults2))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); resultProcessor.awaitForCompletion(); - verifyNoMoreInteractions(client); + InOrder inOrder = Mockito.inOrder(dataFrameRowsJoiner); + inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1); + inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2); } private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } - private void givenSingleDataFrameBatch(List batch) throws IOException { - when(dataExtractor.hasNext()).thenReturn(true).thenReturn(true).thenReturn(false); - when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty()); - } - - private static SearchHit newHit(String json) { - SearchHit hit = new SearchHit(randomInt(), randomAlphaOfLength(10), new Text("doc"), Collections.emptyMap()); - hit.sourceRef(new BytesArray(json)); - return hit; - } - - private static DataFrameDataExtractor.Row 
newRow(SearchHit hit, String[] values, int checksum) { - DataFrameDataExtractor.Row row = mock(DataFrameDataExtractor.Row.class); - when(row.getHit()).thenReturn(hit); - when(row.getValues()).thenReturn(values); - when(row.getChecksum()).thenReturn(checksum); - return row; - } - - private void givenClientHasNoFailures() { - ActionFuture responseFuture = mock(ActionFuture.class); - when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); - when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); - } - private AnalyticsResultProcessor createResultProcessor() { - return new AnalyticsResultProcessor(client, dataExtractor); + return new AnalyticsResultProcessor(dataFrameRowsJoiner); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java index fc46b4e984d26..d0d3b4ee5f99c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java @@ -7,24 +7,20 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResultsTests; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; public class AnalyticsResultTests extends AbstractXContentTestCase { @Override protected AnalyticsResult createTestInstance() { - int checksum = randomInt(); - Map results = new HashMap<>(); - int resultsSize = randomIntBetween(1, 10); - for (int i = 0; i < resultsSize; i++) { - String resultField = randomAlphaOfLength(20); - Object resultValue = randomBoolean() ? 
randomAlphaOfLength(20) : randomDouble(); - results.put(resultField, resultValue); + RowResults rowResults = null; + if (randomBoolean()) { + rowResults = RowResultsTests.createRandom(); } - return new AnalyticsResult(checksum, results); + return new AnalyticsResult(rowResults); } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java new file mode 100644 index 0000000000000..4c6d9e78a9300 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static 
org.hamcrest.Matchers.equalTo; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class DataFrameRowsJoinerTests extends ESTestCase { + + private static final String ANALYTICS_ID = "my_analytics"; + + private Client client; + private DataFrameDataExtractor dataExtractor; + private DataFrameRowsJoiner dataFrameRowsJoiner; + private ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); + + @Before + public void setUpMocks() { + client = mock(Client.class); + dataExtractor = mock(DataFrameDataExtractor.class); + dataFrameRowsJoiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor); + } + + public void testProcess_GivenNoResults() { + givenProcessResults(Collections.emptyList()); + verifyNoMoreInteractions(client); + } + + public void testProcess_GivenSingleRowAndResult() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); + givenSingleDataFrameBatch(Arrays.asList(row)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + RowResults result = new RowResults(1, resultFields); + givenProcessResults(Arrays.asList(result)); + + List capturedBulkRequests = bulkRequestCaptor.getAllValues(); + assertThat(capturedBulkRequests.size(), equalTo(1)); + BulkRequest capturedBulkRequest = capturedBulkRequests.get(0); + assertThat(capturedBulkRequest.numberOfActions(), equalTo(1)); + IndexRequest indexRequest = (IndexRequest) capturedBulkRequest.requests().get(0); + Map indexedDocSource = indexRequest.sourceAsMap(); + assertThat(indexedDocSource.size(), equalTo(4)); + assertThat(indexedDocSource.get("f_1"), equalTo("foo")); + assertThat(indexedDocSource.get("f_2"), equalTo(42.0)); + 
assertThat(indexedDocSource.get("a"), equalTo("1")); + assertThat(indexedDocSource.get("b"), equalTo("2")); + } + + public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); + givenSingleDataFrameBatch(Arrays.asList(row)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + RowResults result = new RowResults(2, resultFields); + givenProcessResults(Arrays.asList(result)); + + verifyNoMoreInteractions(client); + } + + private void givenProcessResults(List results) { + results.forEach(dataFrameRowsJoiner::processRowResults); + } + + private void givenSingleDataFrameBatch(List batch) throws IOException { + when(dataExtractor.hasNext()).thenReturn(true).thenReturn(true).thenReturn(false); + when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty()); + } + + private static SearchHit newHit(String json) { + SearchHit hit = new SearchHit(randomInt(), randomAlphaOfLength(10), new Text("doc"), Collections.emptyMap()); + hit.sourceRef(new BytesArray(json)); + return hit; + } + + private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values, int checksum) { + DataFrameDataExtractor.Row row = mock(DataFrameDataExtractor.Row.class); + when(row.getHit()).thenReturn(hit); + when(row.getValues()).thenReturn(values); + when(row.getChecksum()).thenReturn(checksum); + return row; + } + + private void givenClientHasNoFailures() { + ActionFuture responseFuture = mock(ActionFuture.class); + when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); + when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); + } +} diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java new file mode 100644 index 0000000000000..5fdeee90329ae --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/RowResultsTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.process.results; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class RowResultsTests extends AbstractXContentTestCase { + + @Override + protected RowResults createTestInstance() { + return createRandom(); + } + + public static RowResults createRandom() { + int checksum = randomInt(); + Map results = new HashMap<>(); + int resultsSize = randomIntBetween(1, 10); + for (int i = 0; i < resultsSize; i++) { + String resultField = randomAlphaOfLength(20); + Object resultValue = randomBoolean() ? randomAlphaOfLength(20) : randomDouble(); + results.put(resultField, resultValue); + } + return new RowResults(checksum, results); + } + + @Override + protected RowResults doParseInstance(XContentParser parser) { + return RowResults.PARSER.apply(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } +} From 267ca3ac46ed4444ee307f000c7e3d27f9af4d83 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Tue, 12 Feb 2019 08:49:37 -0600 Subject: [PATCH 17/67] [FEATURE][ML] Add query support for dataframe analytics config (#38576) * ML: Add query support for dataframe analytics config * Adding query to reindex, if needed. 
Also fixing minor bug in extractor * Adding default query to config, adjusting extractor factory * adjusting analytics extractor factory * Adjust config parse to not store default fields of parsed query * fixing reindex and yml tests * Only querying on reindex analytics run * removing unused function --- .../dataframe/DataFrameAnalyticsConfig.java | 108 ++++++++++++++-- .../xpack/core/ml/job/messages/Messages.java | 2 + ...tDataFrameAnalyticsActionRequestTests.java | 18 +++ .../DataFrameAnalyticsConfigTests.java | 117 +++++++++++++++++- ...ransportStartDataFrameAnalyticsAction.java | 8 +- .../dataframe/DataFrameAnalyticsManager.java | 5 +- .../DataFrameDataExtractorFactory.java | 91 ++++++++++---- .../test/ml/data_frame_analytics_crud.yml | 20 +++ 8 files changed, 329 insertions(+), 40 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 028e4347ba8f0..d982dc3bfa91a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -5,30 +5,63 @@ */ package org.elasticsearch.xpack.core.ml.dataframe; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchParseException; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import 
org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { + private static final Logger logger = LogManager.getLogger(DataFrameAnalyticsConfig.class); public static final String TYPE = "data_frame_analytics_config"; + private static final XContentObjectTransformer QUERY_TRANSFORMER = XContentObjectTransformer.queryBuilderTransformer(); + static final TriFunction, String, List, QueryBuilder> lazyQueryParser = + (objectMap, id, warnings) -> { + try { + return QUERY_TRANSFORMER.fromMap(objectMap, warnings); + } catch (IOException | XContentParseException exception) { + // Certain thrown exceptions wrap up the real Illegal argument making it hard to determine cause for the user + if (exception.getCause() instanceof IllegalArgumentException) { + throw ExceptionsHelper.badRequestException( + Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception.getCause()); + } else { + throw ExceptionsHelper.badRequestException( + Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception); + } + } + }; + + public static final ParseField ID = new ParseField("id"); public static final ParseField SOURCE = new ParseField("source"); public static final ParseField DEST = new ParseField("dest"); public static final ParseField ANALYSES = new ParseField("analyses"); public static 
final ParseField CONFIG_TYPE = new ParseField("config_type"); + public static final ParseField QUERY = new ParseField("query"); public static final ObjectParser STRICT_PARSER = createParser(false); public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -41,6 +74,7 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel parser.declareString(Builder::setSource, SOURCE); parser.declareString(Builder::setDest, DEST); parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES); + parser.declareObject((builder, query) -> builder.setQuery(query, ignoreUnknownFields), (p, c) -> p.mapOrdered(), QUERY); return parser; } @@ -48,8 +82,11 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final String source; private final String dest; private final List analyses; + private final Map query; + private final CachedSupplier querySupplier; - public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses) { + public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses, + Map query) { this.id = ExceptionsHelper.requireNonNull(id, ID); this.source = ExceptionsHelper.requireNonNull(source, SOURCE); this.dest = ExceptionsHelper.requireNonNull(dest, DEST); @@ -61,6 +98,8 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List 1) { throw new UnsupportedOperationException("Does not yet support multiple analyses"); } + this.query = Collections.unmodifiableMap(query); + this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); } public DataFrameAnalyticsConfig(StreamInput in) throws IOException { @@ -68,6 +107,8 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException { source = in.readString(); dest = in.readString(); analyses = in.readList(DataFrameAnalysisConfig::new); + this.query = in.readMap(); + this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, 
id, new ArrayList<>())); } public String getId() { @@ -86,6 +127,30 @@ public List getAnalyses() { return analyses; } + @Nullable + public Map getQuery() { + return query; + } + + @Nullable + public QueryBuilder getParsedQuery() { + return querySupplier.get(); + } + + /** + * Calls the lazy parser and returns any gathered deprecations + * @return The deprecations from parsing the query + */ + List getQueryDeprecations() { + return getQueryDeprecations(lazyQueryParser); + } + + List getQueryDeprecations(TriFunction, String, List, QueryBuilder> parser) { + List deprecations = new ArrayList<>(); + parser.apply(query, id, deprecations); + return deprecations; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -96,6 +161,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(ToXContentParams.INCLUDE_TYPE, false)) { builder.field(CONFIG_TYPE.getPreferredName(), TYPE); } + builder.field(QUERY.getPreferredName(), query); builder.endObject(); return builder; } @@ -106,6 +172,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(source); out.writeString(dest); out.writeList(analyses); + out.writeMap(query); } @Override @@ -117,12 +184,13 @@ public boolean equals(Object o) { return Objects.equals(id, other.id) && Objects.equals(source, other.source) && Objects.equals(dest, other.dest) - && Objects.equals(analyses, other.analyses); + && Objects.equals(analyses, other.analyses) + && Objects.equals(query, other.query); } @Override public int hashCode() { - return Objects.hash(id, source, dest, analyses); + return Objects.hash(id, source, dest, analyses, query); } public static String documentId(String id) { @@ -135,29 +203,53 @@ public static class Builder { private String source; private String dest; private List analyses; + private Map query = Collections.singletonMap(MatchAllQueryBuilder.NAME, 
Collections.emptyMap()); public String getId() { return id; } - public void setId(String id) { + public Builder setId(String id) { this.id = ExceptionsHelper.requireNonNull(id, ID); + return this; } - public void setSource(String source) { + public Builder setSource(String source) { this.source = ExceptionsHelper.requireNonNull(source, SOURCE); + return this; } - public void setDest(String dest) { + public Builder setDest(String dest) { this.dest = ExceptionsHelper.requireNonNull(dest, DEST); + return this; } - public void setAnalyses(List analyses) { + public Builder setAnalyses(List analyses) { this.analyses = ExceptionsHelper.requireNonNull(analyses, ANALYSES); + return this; + } + + public Builder setQuery(Map query) { + return setQuery(query, true); + } + + public Builder setQuery(Map query, boolean lenient) { + this.query = ExceptionsHelper.requireNonNull(query, QUERY.getPreferredName()); + try { + QUERY_TRANSFORMER.fromMap(query); + } catch (Exception exception) { + if (lenient) { + logger.warn(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception); + } else { + throw ExceptionsHelper.badRequestException( + Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception); + } + } + return this; } public DataFrameAnalyticsConfig build() { - return new DataFrameAnalyticsConfig(id, source, dest, analyses); + return new DataFrameAnalyticsConfig(id, source, dest, analyses, query); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 77ae8cb26eae9..4bd3ccbc985d7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -50,6 +50,8 @@ public final class Messages { "Datafeed frequency [{0}] must be a multiple of the 
aggregation interval [{1}]"; public static final String DATAFEED_ID_ALREADY_TAKEN = "A datafeed with id [{0}] already exists"; + public static final String DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT = "Data Frame Analytics config [{0}] query is not parsable"; + public static final String FILTER_CANNOT_DELETE = "Cannot delete filter [{0}] currently used by jobs {1}"; public static final String FILTER_CONTAINS_TOO_MANY_ITEMS = "Filter [{0}] contains too many items; up to [{1}] items are allowed"; public static final String FILTER_NOT_FOUND = "No filter with id [{0}] exists"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java index 2d899b7fb2d44..633f34fd88576 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java @@ -5,12 +5,18 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractStreamableXContentTestCase; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Request; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.junit.Before; +import java.util.Collections; + public class PutDataFrameAnalyticsActionRequestTests extends AbstractStreamableXContentTestCase { private String id; @@ -20,6 +26,18 @@ public void setUpId() { id = DataFrameAnalyticsConfigTests.randomValidId(); } + @Override + protected 
NamedWriteableRegistry getNamedWriteableRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + @Override protected Request createTestInstance() { return new Request(DataFrameAnalyticsConfigTests.createRandom(id)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index dfb3e8ef36546..d58c0e5355987 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -6,14 +6,31 @@ package org.elasticsearch.xpack.core.ml.dataframe; import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.TermQueryBuilder; 
+import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; import java.util.Collections; import java.util.List; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { @Override @@ -21,6 +38,18 @@ protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws return DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build(); } + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + @Override protected DataFrameAnalyticsConfig createTestInstance() { return createRandom(randomValidId()); @@ -35,11 +64,95 @@ public static DataFrameAnalyticsConfig createRandom(String id) { String source = randomAlphaOfLength(10); String dest = randomAlphaOfLength(10); List analyses = Collections.singletonList(DataFrameAnalysisConfigTests.randomConfig()); - return new DataFrameAnalyticsConfig(id, source, dest, analyses); + DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder() + .setId(id) + .setAnalyses(analyses) + .setSource(source) + .setDest(dest); + if (randomBoolean()) { + builder.setQuery( + Collections.singletonMap(TermQueryBuilder.NAME, + Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))), true); + } + return builder.build(); } public static String randomValidId() { - CodepointSetGenerator generator 
= new CodepointSetGenerator("abcdefghijklmnopqrstuvwxyz".toCharArray()); + CodepointSetGenerator generator = new CodepointSetGenerator("abcdefghijklmnopqrstuvwxyz".toCharArray()); return generator.ofCodePointsLength(random(), 10, 10); } + + private static final String ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS = "{\n" + + " \"id\": \"old-data-frame\",\n" + + " \"source\": \"my-index\",\n" + + " \"dest\": \"dest-index\",\n" + + " \"analyses\": {\"outlier_detection\": {\"number_neighbours\": 10}},\n" + + //query:match:type stopped being supported in 6.x + " \"query\": {\"match\" : {\"query\":\"fieldName\", \"type\": \"phrase\"}}\n" + + "}"; + + private static final String MODERN_QUERY_DATA_FRAME_ANALYTICS = "{\n" + + " \"id\": \"data-frame\",\n" + + " \"source\": \"my-index\",\n" + + " \"dest\": \"dest-index\",\n" + + " \"analyses\": {\"outlier_detection\": {\"number_neighbours\": 10}},\n" + + // match_all if parsed, adds default values in the options + " \"query\": {\"match_all\" : {}}\n" + + "}"; + + public void testQueryConfigStoresUserInputOnly() throws IOException { + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + MODERN_QUERY_DATA_FRAME_ANALYTICS)) { + + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + assertThat(config.getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); + } + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + MODERN_QUERY_DATA_FRAME_ANALYTICS)) { + + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build(); + assertThat(config.getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); + } + } + + public void 
testPastQueryConfigParse() throws IOException { + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) { + + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> config.getParsedQuery()); + assertEquals("[match] query doesn't support multiple fields, found [query] and [type]", e.getMessage()); + } + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) { + + XContentParseException e = expectThrows(XContentParseException.class, + () -> DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build()); + assertEquals("[6:64] [data_frame_analytics_config] failed to parse field [query]", e.getMessage()); + } + } + + public void testGetQueryDeprecations() { + DataFrameAnalyticsConfig dataFrame = createTestInstance(); + String deprecationWarning = "Warning"; + List deprecations = dataFrame.getQueryDeprecations((map, id, deprecationlist) -> { + deprecationlist.add(deprecationWarning); + return new BoolQueryBuilder(); + }); + assertThat(deprecations, hasItem(deprecationWarning)); + + DataFrameAnalyticsConfig spiedConfig = spy(dataFrame); + spiedConfig.getQueryDeprecations(); + verify(spiedConfig).getQueryDeprecations(DataFrameAnalyticsConfig.lazyQueryParser); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index e45b16a6c6786..f518297915d1d 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -122,16 +122,16 @@ public void onFailure(Exception e) { }; // Start persistent task - ActionListener validatedConfigListener = ActionListener.wrap( - config -> persistentTasksService.sendStartRequest(MlTasks.dataFrameAnalyticsTaskId(request.getId()), + ActionListener validateListener = ActionListener.wrap( + validated -> persistentTasksService.sendStartRequest(MlTasks.dataFrameAnalyticsTaskId(request.getId()), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, taskParams, waitForAnalyticsToStart), listener::onFailure ); // Validate config ActionListener configListener = ActionListener.wrap( - config -> DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), config.getId(), config.getSource(), - ActionListener.wrap(dataExtractorFactory -> validatedConfigListener.onResponse(config), listener::onFailure)), + config -> + DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, Collections.emptyMap(), config, validateListener), listener::onFailure ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 72b1319689820..ced87744d4eaa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -108,6 +108,8 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF createIndexResponse -> { ReindexRequest reindexRequest = new ReindexRequest(); reindexRequest.setSourceIndices(config.getSource()); + // we default to match_all + reindexRequest.setSourceQuery(config.getParsedQuery()); 
reindexRequest.setDestIndex(config.getDest()); reindexRequest.setScript(new Script("ctx._source." + DataFrameAnalyticsFields.ID + " = ctx._id")); client.execute(ReindexAction.INSTANCE, reindexRequest, reindexCompletedListener); @@ -142,8 +144,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi // TODO This could fail with errors. In that case we get stuck with the copied index. // We could delete the index in case of failure or we could try building the factory before reindexing // to catch the error early on. - DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), config.getId(), config.getDest(), - dataExtractorFactoryListener); + DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), config, dataExtractorFactoryListener); } private void createDestinationIndex(String sourceIndex, String destinationIndex, ActionListener listener) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index a9b225dee77ac..3d076ace1e348 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -16,6 +16,7 @@ import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; @@ -78,33 +79,47 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { return new 
DataFrameDataExtractor(client, context); } - public static void create(Client client, Map headers, String analyticsId, String index, + /** + * Validate and create a new extractor factory + * + * The destination index must exist and contain at least 1 compatible field or validations will fail. + * + * @param client ES Client used to make calls against the cluster + * @param headers Headers to use + * @param config The config from which to create the extractor factory + * @param listener The listener to notify on creation or failure + */ + public static void create(Client client, + Map headers, + DataFrameAnalyticsConfig config, ActionListener listener) { - // Step 2. Contruct the factory and notify listener - ActionListener fieldCapabilitiesHandler = ActionListener.wrap( - fieldCapabilitiesResponse -> { - listener.onResponse(new DataFrameDataExtractorFactory(client, analyticsId, index, - detectExtractedFields(index, fieldCapabilitiesResponse))); - }, e -> { - if (e instanceof IndexNotFoundException) { - listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " - + ((IndexNotFoundException) e).getIndex() + " does not exist")); - } else { - listener.onFailure(e); - } - } - ); + validateIndexAndExtractFields(client, headers, config.getDest(), ActionListener.wrap( + extractedFields -> listener.onResponse( + new DataFrameDataExtractorFactory(client, config.getId(), config.getDest(), extractedFields)), + listener::onFailure + )); + } - // Step 1. 
Get field capabilities necessary to build the information of how to extract fields - FieldCapabilitiesRequest fieldCapabilitiesRequest = new FieldCapabilitiesRequest(); - fieldCapabilitiesRequest.indices(index); - fieldCapabilitiesRequest.fields("*"); - ClientHelper.executeWithHeaders(headers, ClientHelper.ML_ORIGIN, client, () -> { - client.execute(FieldCapabilitiesAction.INSTANCE, fieldCapabilitiesRequest, fieldCapabilitiesHandler); - // This response gets discarded - the listener handles the real response - return null; - }); + /** + * Validates the source index and analytics config + * + * @param client ES Client to make calls + * @param headers Headers for auth + * @param config Analytics config to validate + * @param listener The listener to notify on failure or completion + */ + public static void validateConfigAndSourceIndex(Client client, + Map headers, + DataFrameAnalyticsConfig config, + ActionListener listener) { + validateIndexAndExtractFields(client, headers, config.getSource(), ActionListener.wrap( + fields -> { + config.getParsedQuery(); // validate query is acceptable + listener.onResponse(true); + }, + listener::onFailure + )); } // Visible for testing @@ -133,4 +148,32 @@ private static void removeFieldsWithIncompatibleTypes(Set fields, FieldC } } } + + private static void validateIndexAndExtractFields(Client client, + Map headers, + String index, + ActionListener listener) { + // Step 2. Extract fields (if possible) and notify listener + ActionListener fieldCapabilitiesHandler = ActionListener.wrap( + fieldCapabilitiesResponse -> listener.onResponse(detectExtractedFields(index, fieldCapabilitiesResponse)), + e -> { + if (e instanceof IndexNotFoundException) { + listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " + + ((IndexNotFoundException) e).getIndex() + " does not exist")); + } else { + listener.onFailure(e); + } + } + ); + + // Step 1. 
Get field capabilities necessary to build the information of how to extract fields + FieldCapabilitiesRequest fieldCapabilitiesRequest = new FieldCapabilitiesRequest(); + fieldCapabilitiesRequest.indices(index); + fieldCapabilitiesRequest.fields("*"); + ClientHelper.executeWithHeaders(headers, ClientHelper.ML_ORIGIN, client, () -> { + client.execute(FieldCapabilitiesAction.INSTANCE, fieldCapabilitiesRequest, fieldCapabilitiesHandler); + // This response gets discarded - the listener handles the real response + return null; + }); + } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 85e07d63cba4a..4617b7c59a322 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -25,6 +25,25 @@ - match: { count: 0 } - match: { data_frame_analytics: [] } +--- +"Test put valid config with default outlier detection and query": + + - do: + ml.put_data_frame_analytics: + id: "simple-outlier-detection-with-query" + body: > + { + "source": "source_index", + "dest": "dest_index", + "analyses": [{"outlier_detection":{}}], + "query": {"term" : { "user" : "Kimchy" }} + } + - match: { id: "simple-outlier-detection-with-query" } + - match: { source: "source_index" } + - match: { dest: "dest_index" } + - match: { analyses: [{"outlier_detection":{}}] } + - match: { query: {"term" : { "user" : "Kimchy"} } } + --- "Test put valid config with default outlier detection": @@ -41,6 +60,7 @@ - match: { source: "source_index" } - match: { dest: "dest_index" } - match: { analyses: [{"outlier_detection":{}}] } + - match: { query: {"match_all" : {} } } --- "Test put config with inconsistent body/param ids": From 2b6e6c073a28b06868d0bae85c34867572101fa5 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 12 Feb 2019 
17:40:51 +0200 Subject: [PATCH 18/67] Revert "change the download location of the ml native code build (#36733)" This reverts commit 9bb2a02f44c4b719902ab5d254902afa15514804. --- x-pack/plugin/ml/cpp-snapshot/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/cpp-snapshot/build.gradle b/x-pack/plugin/ml/cpp-snapshot/build.gradle index 9df65d02d4afd..e5b55293159aa 100644 --- a/x-pack/plugin/ml/cpp-snapshot/build.gradle +++ b/x-pack/plugin/ml/cpp-snapshot/build.gradle @@ -8,7 +8,7 @@ ext.version = VersionProperties.elasticsearch // for this project so it can be used with dependency substitution. void getZip(File snapshotZip) { - String zipUrl = "http://prelert-artifacts.s3.amazonaws.com/maven/org/elasticsearch/ml/ml-cpp-df/${version}/ml-cpp-df-${version}.zip" + String zipUrl = "http://prelert-artifacts.s3.amazonaws.com/maven/org/elasticsearch/ml/ml-cpp/${version}/ml-cpp-${version}.zip" File snapshotMd5 = new File(snapshotZip.toString() + '.md5') HttpURLConnection conn = (HttpURLConnection) new URL(zipUrl).openConnection(); From a2ab55b9472db1eeb76b796b690a3aaf9ec2726f Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 15 Feb 2019 16:14:07 -0600 Subject: [PATCH 19/67] [FEATURE][ML] User config appropriate permission checks on creating/running analytics (#38928) * [Feature][ML] Add authz check for dataframe source index * fixing origin for client calls and adding headers * addressing PR comments * Having bulk request be done with headers in origin * addressing pr comments and failing test * making analyses immutable * adjusting indexnames and privs for security tests --- .../ml/dataframe/DataFrameAnalysisConfig.java | 6 +- .../dataframe/DataFrameAnalyticsConfig.java | 43 ++++++++- .../DataFrameAnalyticsConfigTests.java | 42 ++++++++- .../ml/qa/ml-with-security/build.gradle | 1 + .../plugin/ml/qa/ml-with-security/roles.yml | 4 +- .../TransportPutDataFrameAnalyticsAction.java | 83 +++++++++++++++-- 
...ransportStartDataFrameAnalyticsAction.java | 3 +- .../dataframe/DataFrameAnalyticsManager.java | 30 ++++-- .../extractor/DataFrameDataExtractor.java | 6 ++ .../DataFrameDataExtractorFactory.java | 17 ++-- .../DataFrameAnalyticsConfigProvider.java | 18 +++- .../process/DataFrameRowsJoiner.java | 7 +- .../process/DataFrameRowsJoinerTests.java | 7 ++ .../test/ml/data_frame_analytics_crud.yml | 91 +++++++++++-------- 14 files changed, 284 insertions(+), 74 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java index acdb21b44dadf..cc4106218e098 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java @@ -14,6 +14,8 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -26,14 +28,14 @@ public static ContextParser parser() { private final Map config; public DataFrameAnalysisConfig(Map config) { - this.config = Objects.requireNonNull(config); + this.config = Collections.unmodifiableMap(new HashMap<>(Objects.requireNonNull(config))); if (config.size() != 1) { throw ExceptionsHelper.badRequestException("A data frame analysis must specify exactly one analysis type"); } } public DataFrameAnalysisConfig(StreamInput in) throws IOException { - config = in.readMap(); + config = Collections.unmodifiableMap(in.readMap()); } public Map asMap() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 
d982dc3bfa91a..976774ed59371 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -29,6 +29,8 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -62,6 +64,7 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { public static final ParseField ANALYSES = new ParseField("analyses"); public static final ParseField CONFIG_TYPE = new ParseField("config_type"); public static final ParseField QUERY = new ParseField("query"); + public static final ParseField HEADERS = new ParseField("headers"); public static final ObjectParser STRICT_PARSER = createParser(false); public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -75,6 +78,11 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel parser.declareString(Builder::setDest, DEST); parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES); parser.declareObject((builder, query) -> builder.setQuery(query, ignoreUnknownFields), (p, c) -> p.mapOrdered(), QUERY); + if (ignoreUnknownFields) { + // Headers are not parsed by the strict (config) parser, so headers supplied in the _body_ of a REST request will be rejected. + // (For config, headers are explicitly transferred from the auth headers by code in the put data frame actions.) 
+ parser.declareObject(Builder::setHeaders, (p, c) -> p.mapStrings(), HEADERS); + } return parser; } @@ -84,9 +92,10 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final List analyses; private final Map query; private final CachedSupplier querySupplier; + private final Map headers; public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses, - Map query) { + Map query, Map headers) { this.id = ExceptionsHelper.requireNonNull(id, ID); this.source = ExceptionsHelper.requireNonNull(source, SOURCE); this.dest = ExceptionsHelper.requireNonNull(dest, DEST); @@ -100,6 +109,7 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); + this.headers = Collections.unmodifiableMap(headers); } public DataFrameAnalyticsConfig(StreamInput in) throws IOException { @@ -109,6 +119,7 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException { analyses = in.readList(DataFrameAnalysisConfig::new); this.query = in.readMap(); this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); + this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); } public String getId() { @@ -151,6 +162,10 @@ List getQueryDeprecations(TriFunction, String, List< return deprecations; } + public Map getHeaders() { + return headers; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -162,6 +177,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(CONFIG_TYPE.getPreferredName(), TYPE); } builder.field(QUERY.getPreferredName(), query); + if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(HEADERS.getPreferredName(), headers); + } builder.endObject(); return builder; } @@ -173,6 
+191,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(dest); out.writeList(analyses); out.writeMap(query); + out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); } @Override @@ -185,12 +204,13 @@ public boolean equals(Object o) { && Objects.equals(source, other.source) && Objects.equals(dest, other.dest) && Objects.equals(analyses, other.analyses) + && Objects.equals(headers, other.headers) && Objects.equals(query, other.query); } @Override public int hashCode() { - return Objects.hash(id, source, dest, analyses, query); + return Objects.hash(id, source, dest, analyses, query, headers); } public static String documentId(String id) { @@ -204,11 +224,23 @@ public static class Builder { private String dest; private List analyses; private Map query = Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()); + private Map headers = Collections.emptyMap(); public String getId() { return id; } + public Builder() {} + + public Builder(DataFrameAnalyticsConfig config) { + this.id = config.id; + this.source = config.source; + this.dest = config.dest; + this.analyses = new ArrayList<>(config.analyses); + this.query = new LinkedHashMap<>(config.query); + this.headers = new HashMap<>(config.headers); + } + public Builder setId(String id) { this.id = ExceptionsHelper.requireNonNull(id, ID); return this; @@ -248,8 +280,13 @@ public Builder setQuery(Map query, boolean lenient) { return this; } + public Builder setHeaders(Map headers) { + this.headers = headers; + return this; + } + public DataFrameAnalyticsConfig build() { - return new DataFrameAnalyticsConfig(id, source, dest, analyses, query); + return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers); } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index d58c0e5355987..987b0deceb77d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -7,12 +7,16 @@ import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; @@ -21,13 +25,18 @@ import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -61,6 +70,10 @@ protected Writeable.Reader 
instanceReader() { } public static DataFrameAnalyticsConfig createRandom(String id) { + return createRandomBuilder(id).build(); + } + + public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) { String source = randomAlphaOfLength(10); String dest = randomAlphaOfLength(10); List analyses = Collections.singletonList(DataFrameAnalysisConfigTests.randomConfig()); @@ -74,7 +87,7 @@ public static DataFrameAnalyticsConfig createRandom(String id) { Collections.singletonMap(TermQueryBuilder.NAME, Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))), true); } - return builder.build(); + return builder; } public static String randomValidId() { @@ -142,6 +155,33 @@ public void testPastQueryConfigParse() throws IOException { } } + public void testToXContentForInternalStorage() throws IOException { + DataFrameAnalyticsConfig.Builder builder = createRandomBuilder("foo"); + + // headers are only persisted to cluster state + Map headers = new HashMap<>(); + headers.put("header-name", "header-value"); + builder.setHeaders(headers); + DataFrameAnalyticsConfig config = builder.build(); + + ToXContent.MapParams params = new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + + BytesReference forClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, params, false); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, forClusterstateXContent.streamInput()); + + DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + assertThat(parsedConfig.getHeaders(), hasEntry("header-name", "header-value")); + + // headers are not written without the FOR_INTERNAL_STORAGE param + BytesReference nonClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); + parser = 
XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, nonClusterstateXContent.streamInput()); + + parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + assertThat(parsedConfig.getHeaders().entrySet(), hasSize(0)); + } + public void testGetQueryDeprecations() { DataFrameAnalyticsConfig dataFrame = createTestInstance(); String deprecationWarning = "Warning"; diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 2ddcd3f276027..2003f3c18aa2e 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -35,6 +35,7 @@ integTestRunner { 'ml/datafeeds_crud/Test put datafeed with invalid query', 'ml/datafeeds_crud/Test put datafeed with security headers in the body', 'ml/datafeeds_crud/Test update datafeed with missing id', + 'ml/data_frame_analytics_crud/Test put config with security headers in the body', 'ml/data_frame_analytics_crud/Test put config with inconsistent body/param ids', 'ml/data_frame_analytics_crud/Test put config with invalid id', 'ml/data_frame_analytics_crud/Test put config with unknown top level field', diff --git a/x-pack/plugin/ml/qa/ml-with-security/roles.yml b/x-pack/plugin/ml/qa/ml-with-security/roles.yml index e47fe40a120cd..8533b81c07377 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/roles.yml +++ b/x-pack/plugin/ml/qa/ml-with-security/roles.yml @@ -11,7 +11,7 @@ minimal: privileges: - indices:admin/create - indices:admin/refresh - - indices:data/read/field_caps - - indices:data/read/search + - read + - index - indices:data/write/bulk - indices:data/write/index diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java index e96ca02ce4f89..d0fbc613896f5 
100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java @@ -8,20 +8,35 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.core.security.SecurityContext; +import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesAction; +import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesRequest; +import org.elasticsearch.xpack.core.security.action.user.HasPrivilegesResponse; +import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; +import org.elasticsearch.xpack.core.security.authz.permission.ResourcePrivileges; +import org.elasticsearch.xpack.core.security.support.Exceptions; import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; import 
org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; +import java.io.IOException; import java.util.function.Supplier; public class TransportPutDataFrameAnalyticsAction @@ -29,14 +44,22 @@ public class TransportPutDataFrameAnalyticsAction private final XPackLicenseState licenseState; private final DataFrameAnalyticsConfigProvider configProvider; + private final ThreadPool threadPool; + private final SecurityContext securityContext; + private final Client client; @Inject - public TransportPutDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters, - XPackLicenseState licenseState, DataFrameAnalyticsConfigProvider configProvider) { + public TransportPutDataFrameAnalyticsAction(Settings settings, TransportService transportService, ActionFilters actionFilters, + XPackLicenseState licenseState, Client client, ThreadPool threadPool, + DataFrameAnalyticsConfigProvider configProvider) { super(PutDataFrameAnalyticsAction.NAME, transportService, actionFilters, (Supplier) PutDataFrameAnalyticsAction.Request::new); this.licenseState = licenseState; this.configProvider = configProvider; + this.threadPool = threadPool; + this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ? 
+ new SecurityContext(settings, threadPool.getThreadContext()) : null; + this.client = client; } @Override @@ -46,12 +69,58 @@ protected void doExecute(Task task, PutDataFrameAnalyticsAction.Request request, listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); return; } - validateConfig(request.getConfig()); - configProvider.put(request.getConfig(), ActionListener.wrap( - indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())), - listener::onFailure - )); + if (licenseState.isAuthAllowed()) { + final String username = securityContext.getUser().principal(); + RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() + .indices(request.getConfig().getSource()) + .privileges("read") + .build(); + RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() + .indices(request.getConfig().getDest()) + .privileges("read", "index", "create_index") + .build(); + + HasPrivilegesRequest privRequest = new HasPrivilegesRequest(); + privRequest.applicationPrivileges(new RoleDescriptor.ApplicationResourcePrivileges[0]); + privRequest.username(username); + privRequest.clusterPrivileges(Strings.EMPTY_ARRAY); + privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges); + + ActionListener privResponseListener = ActionListener.wrap( + r -> handlePrivsResponse(username, request, r, listener), + listener::onFailure); + + client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener); + } else { + configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap( + indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())), + listener::onFailure + )); + } + } + + private void handlePrivsResponse(String username, PutDataFrameAnalyticsAction.Request request, + HasPrivilegesResponse response, + ActionListener listener) throws 
IOException { + if (response.isCompleteMatch()) { + configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap( + indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())), + listener::onFailure + )); + } else { + XContentBuilder builder = JsonXContent.contentBuilder(); + builder.startObject(); + for (ResourcePrivileges index : response.getIndexPrivileges()) { + builder.field(index.getResource()); + builder.map(index.getPrivileges()); + } + builder.endObject(); + + listener.onFailure(Exceptions.authorizationError("Cannot create data frame analytics [{}]" + + " because user {} lacks permissions on the indices: {}", + request.getConfig().getId(), username, Strings.toString(builder))); + } } private void validateConfig(DataFrameAnalyticsConfig config) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index f518297915d1d..5b18eeda7d277 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -41,7 +41,6 @@ import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; -import java.util.Collections; import java.util.Map; import java.util.Objects; import java.util.function.Predicate; @@ -131,7 +130,7 @@ public void onFailure(Exception e) { // Validate config ActionListener configListener = ActionListener.wrap( config -> - DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, Collections.emptyMap(), config, validateListener), + DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, config, 
validateListener), listener::onFailure ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index ced87744d4eaa..0395f3115bbc3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -25,6 +25,7 @@ import org.elasticsearch.index.reindex.ReindexRequest; import org.elasticsearch.script.Script; import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; @@ -34,7 +35,6 @@ import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -99,7 +99,12 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF // Refresh to ensure copied index is fully searchable ActionListener reindexCompletedListener = ActionListener.wrap( - bulkResponse -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest()), refreshListener), + bulkResponse -> + ClientHelper.executeAsyncWithOrigin(client, + ClientHelper.ML_ORIGIN, + RefreshAction.INSTANCE, + new RefreshRequest(config.getDest()), + refreshListener), e -> task.markAsFailed(e) ); @@ -112,12 +117,17 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF reindexRequest.setSourceQuery(config.getParsedQuery()); reindexRequest.setDestIndex(config.getDest()); reindexRequest.setScript(new Script("ctx._source." 
+ DataFrameAnalyticsFields.ID + " = ctx._id")); - client.execute(ReindexAction.INSTANCE, reindexRequest, reindexCompletedListener); + ClientHelper.executeWithHeadersAsync(config.getHeaders(), + ClientHelper.ML_ORIGIN, + client, + ReindexAction.INSTANCE, + reindexRequest, + reindexCompletedListener); }, reindexCompletedListener::onFailure ); - createDestinationIndex(config.getSource(), config.getDest(), copyIndexCreatedListener); + createDestinationIndex(config.getSource(), config.getDest(), config.getHeaders(), copyIndexCreatedListener); } private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { @@ -144,10 +154,11 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi // TODO This could fail with errors. In that case we get stuck with the copied index. // We could delete the index in case of failure or we could try building the factory before reindexing // to catch the error early on. - DataFrameDataExtractorFactory.create(client, Collections.emptyMap(), config, dataExtractorFactoryListener); + DataFrameDataExtractorFactory.create(client, config, dataExtractorFactoryListener); } - private void createDestinationIndex(String sourceIndex, String destinationIndex, ActionListener listener) { + private void createDestinationIndex(String sourceIndex, String destinationIndex, Map headers, + ActionListener listener) { IndexMetaData indexMetaData = clusterService.state().getMetaData().getIndices().get(sourceIndex); if (indexMetaData == null) { listener.onFailure(new IndexNotFoundException(sourceIndex)); @@ -161,7 +172,12 @@ private void createDestinationIndex(String sourceIndex, String destinationIndex, CreateIndexRequest createIndexRequest = new CreateIndexRequest(destinationIndex, settingsBuilder.build()); addDestinationIndexMappings(indexMetaData, createIndexRequest); - client.execute(CreateIndexAction.INSTANCE, createIndexRequest, listener); + ClientHelper.executeWithHeadersAsync(headers, + 
ClientHelper.ML_ORIGIN, + client, + CreateIndexAction.INSTANCE, + createIndexRequest, + listener); } private static void addDestinationIndexMappings(IndexMetaData indexMetaData, CreateIndexRequest createIndexRequest) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 055a5e7d8dd64..7f01f800d1d71 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -27,7 +27,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; import java.util.Optional; @@ -59,6 +61,10 @@ public class DataFrameDataExtractor { searchHasShardFailure = false; } + public Map getHeaders() { + return Collections.unmodifiableMap(context.headers); + } + public boolean hasNext() { return hasNext; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 3d076ace1e348..2f0db834d8578 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -58,12 +58,15 @@ public class DataFrameDataExtractorFactory { private final String analyticsId; private final String index; private final ExtractedFields extractedFields; + private final Map headers; - private DataFrameDataExtractorFactory(Client client, String 
analyticsId, String index, ExtractedFields extractedFields) { + private DataFrameDataExtractorFactory(Client client, String analyticsId, String index, ExtractedFields extractedFields, + Map headers) { this.client = Objects.requireNonNull(client); this.analyticsId = Objects.requireNonNull(analyticsId); this.index = Objects.requireNonNull(index); this.extractedFields = Objects.requireNonNull(extractedFields); + this.headers = headers; } public DataFrameDataExtractor newExtractor(boolean includeSource) { @@ -73,7 +76,7 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { Arrays.asList(index), QueryBuilders.matchAllQuery(), 1000, - Collections.emptyMap(), + headers, includeSource ); return new DataFrameDataExtractor(client, context); @@ -85,18 +88,16 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { * The destination index must exist and contain at least 1 compatible field or validations will fail. * * @param client ES Client used to make calls against the cluster - * @param headers Headers to use * @param config The config from which to create the extractor factory * @param listener The listener to notify on creation or failure */ public static void create(Client client, - Map headers, DataFrameAnalyticsConfig config, ActionListener listener) { - validateIndexAndExtractFields(client, headers, config.getDest(), ActionListener.wrap( + validateIndexAndExtractFields(client, config.getHeaders(), config.getDest(), ActionListener.wrap( extractedFields -> listener.onResponse( - new DataFrameDataExtractorFactory(client, config.getId(), config.getDest(), extractedFields)), + new DataFrameDataExtractorFactory(client, config.getId(), config.getDest(), extractedFields, config.getHeaders())), listener::onFailure )); } @@ -105,15 +106,13 @@ public static void create(Client client, * Validates the source index and analytics config * * @param client ES Client to make calls - * @param headers Headers for auth * @param config Analytics config to 
validate * @param listener The listener to notify on failure or completion */ public static void validateConfigAndSourceIndex(Client client, - Map headers, DataFrameAnalyticsConfig config, ActionListener listener) { - validateIndexAndExtractFields(client, headers, config.getSource(), ActionListener.wrap( + validateIndexAndExtractFields(client, config.getHeaders(), config.getSource(), ActionListener.wrap( fields -> { config.getParsedQuery(); // validate query is acceptable listener.onResponse(true); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java index ed340d155a8fd..5ae2358bcb2fb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -29,6 +30,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -40,6 +42,7 @@ public class DataFrameAnalyticsConfigProvider { static { Map modifiable = new HashMap<>(); modifiable.put(ToXContentParams.INCLUDE_TYPE, "true"); + 
modifiable.put(ToXContentParams.FOR_INTERNAL_STORAGE, "true"); TO_XCONTENT_PARAMS = Collections.unmodifiableMap(modifiable); } @@ -49,7 +52,18 @@ public DataFrameAnalyticsConfigProvider(Client client) { this.client = Objects.requireNonNull(client); } - public void put(DataFrameAnalyticsConfig config, ActionListener listener) { + public void put(DataFrameAnalyticsConfig config, Map headers, ActionListener listener) { + String id = config.getId(); + + if (headers.isEmpty() == false) { + // Filter any values in headers that aren't security fields + DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(config); + Map securityHeaders = headers.entrySet().stream() + .filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + builder.setHeaders(securityHeaders); + config = builder.build(); + } try (XContentBuilder builder = XContentFactory.jsonBuilder()) { config.toXContent(builder, new ToXContent.MapParams(TO_XCONTENT_PARAMS)); IndexRequest indexRequest = new IndexRequest(AnomalyDetectorsIndex.configIndexName()) @@ -62,7 +76,7 @@ public void put(DataFrameAnalyticsConfig config, ActionListener l listener::onResponse, e -> { if (e instanceof VersionConflictEngineException) { - listener.onFailure(ExceptionsHelper.dataFrameAnalyticsAlreadyExists(config.getId())); + listener.onFailure(ExceptionsHelper.dataFrameAnalyticsAlreadyExists(id)); } else { listener.onFailure(e); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java index 76ebe166a39ad..a86645e4fc52d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java @@ -15,6 +15,7 @@ import 
org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.client.Client; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; @@ -105,7 +106,11 @@ private void joinCurrentResults() { bulkRequest.add(indexRequest); } if (bulkRequest.numberOfActions() > 0) { - BulkResponse bulkResponse = client.execute(BulkAction.INSTANCE, bulkRequest).actionGet(); + BulkResponse bulkResponse = + ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), + ClientHelper.ML_ORIGIN, + client, + () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()); if (bulkResponse.hasFailures()) { LOGGER.error("Failures while writing data frame"); // TODO Better error handling diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java index 4c6d9e78a9300..a4795c0ad7cc6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java @@ -13,9 +13,12 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.search.SearchHit; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.junit.Before; @@ -124,8 
+127,12 @@ private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values, } private void givenClientHasNoFailures() { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(threadContext); ActionFuture responseFuture = mock(ActionFuture.class); when(responseFuture.actionGet()).thenReturn(new BulkResponse(new BulkItemResponse[0], 0)); when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); + when(client.threadPool()).thenReturn(threadPool); } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 4617b7c59a322..a8a1c10de0164 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -33,17 +33,32 @@ id: "simple-outlier-detection-with-query" body: > { - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}], "query": {"term" : { "user" : "Kimchy" }} } - match: { id: "simple-outlier-detection-with-query" } - - match: { source: "source_index" } - - match: { dest: "dest_index" } + - match: { source: "index-source" } + - match: { dest: "index-dest" } - match: { analyses: [{"outlier_detection":{}}] } - match: { query: {"term" : { "user" : "Kimchy"} } } +--- +"Test put config with security headers in the body": + - do: + catch: /unknown field \[headers\], parser not found/ + ml.put_data_frame_analytics: + id: "data_frame_with_header" + body: > + { + "source": "index-source", + "dest": "index-dest", + "analyses": [{"outlier_detection":{}}], + "query": {"term" : { "user" : "Kimchy" }}, + "headers":{ "a_security_header" : "secret" } + } + --- "Test 
put valid config with default outlier detection": @@ -52,13 +67,13 @@ id: "simple-outlier-detection" body: > { - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}] } - match: { id: "simple-outlier-detection" } - - match: { source: "source_index" } - - match: { dest: "dest_index" } + - match: { source: "index-source" } + - match: { dest: "index-dest" } - match: { analyses: [{"outlier_detection":{}}] } - match: { query: {"match_all" : {} } } @@ -72,8 +87,8 @@ body: > { "id": "body_id", - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}] } @@ -86,8 +101,8 @@ id: "this id contains spaces" body: > { - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}] } @@ -100,8 +115,8 @@ id: "unknown_field" body: > { - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}], "unknown_field": 42 } @@ -115,8 +130,8 @@ id: "unknown_field" body: > { - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{"unknown_field": 42}}] } @@ -129,7 +144,7 @@ id: "simple-outlier-detection" body: > { - "dest": "dest_index", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}] } @@ -142,7 +157,7 @@ id: "simple-outlier-detection" body: > { - "source": "source_index", + "source": "index-source", "analyses": [{"outlier_detection":{}}] } @@ -155,8 +170,8 @@ id: "simple-outlier-detection" body: > { - "source": "source_index", - "dest": "dest_index" + "source": "index-source", + "dest": "index-dest" } --- @@ -168,8 +183,8 @@ id: "simple-outlier-detection" body: > { - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", 
"analyses": [] } @@ -182,8 +197,8 @@ id: "simple-outlier-detection" body: > { - "source": "source_index", - "dest": "dest_index", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}, {"outlier_detection":{}}] } @@ -195,8 +210,8 @@ id: "foo-1" body: > { - "source": "foo-1_source", - "dest": "foo-1_dest", + "source": "index-foo-1_source", + "dest": "index-foo-1_dest", "analyses": [{"outlier_detection":{}}] } @@ -205,8 +220,8 @@ id: "foo-2" body: > { - "source": "foo-2_source", - "dest": "foo-2_dest", + "source": "index-foo-2_source", + "dest": "index-foo-2_dest", "analyses": [{"outlier_detection":{}}] } - match: { id: "foo-2" } @@ -216,8 +231,8 @@ id: "bar" body: > { - "source": "bar_source", - "dest": "bar_dest", + "source": "index-bar_source", + "dest": "index-bar_dest", "analyses": [{"outlier_detection":{}}] } - match: { id: "bar" } @@ -280,8 +295,8 @@ id: "foo-1" body: > { - "source": "foo-1_source", - "dest": "foo-1_dest", + "source": "index-foo-1_source", + "dest": "index-foo-1_dest", "analyses": [{"outlier_detection":{}}] } @@ -290,8 +305,8 @@ id: "foo-2" body: > { - "source": "foo-2_source", - "dest": "foo-2_dest", + "source": "index-foo-2_source", + "dest": "index-foo-2_dest", "analyses": [{"outlier_detection":{}}] } - match: { id: "foo-2" } @@ -301,8 +316,8 @@ id: "bar" body: > { - "source": "bar_source", - "dest": "bar_dest", + "source": "index-bar_source", + "dest": "index-bar_dest", "analyses": [{"outlier_detection":{}}] } - match: { id: "bar" } @@ -366,8 +381,8 @@ id: "foo" body: > { - "source": "source", - "dest": "dest", + "source": "index-source", + "dest": "index-dest", "analyses": [{"outlier_detection":{}}] } From deac81b44a281fe793798043e42aca1ac9678823 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 18 Feb 2019 19:04:07 +0200 Subject: [PATCH 20/67] [FEATURE][ML] Wire in data frame analytics progress reporting (#38808) Adds progress reporting. Progress is reported per state. 
In particular, this adds progress reporting for the reindexing state and the analyzing state. For reindexing, we now store the reindex task id and we use it to get the task info and calculate progress by taking into consideration the number of docs created against the total docs. For analyzing, we read the progress reported from the native process and store it in memory. The get tasks action has been changed to direct to the node running the process when possible. Then, progress is reported additionally to the rest of stats for running tasks. This commit adds integration tests on the multi-node environment. Those tests have revealed some issues which are also fixed here: - Registering named content correctly - Wait for task state to be `started` before responding in the start API --- .../xpack/core/XPackClientPlugin.java | 9 + .../elasticsearch/xpack/core/ml/MlTasks.java | 2 +- .../action/AbstractGetResourcesRequest.java | 8 +- .../action/GetDataFrameAnalyticsAction.java | 4 + .../GetDataFrameAnalyticsStatsAction.java | 62 +++++-- .../action/StartDataFrameAnalyticsAction.java | 16 ++ .../dataframe/DataFrameAnalyticsConfig.java | 10 +- .../DataFrameAnalyticsTaskState.java | 12 +- ...rameAnalyticsStatsActionResponseTests.java | 3 +- ...NativeDataFrameAnalyticsIntegTestCase.java | 112 ++++++++++++ .../ml/integration/MlNativeIntegTestCase.java | 6 + .../integration/RunDataFrameAnalyticsIT.java | 161 ++++++++++++++++++ .../xpack/ml/MachineLearning.java | 4 +- ...sportGetDataFrameAnalyticsStatsAction.java | 148 +++++++++++++--- ...ransportStartDataFrameAnalyticsAction.java | 106 +++++++++++- .../dataframe/DataFrameAnalyticsManager.java | 31 ++-- .../process/AnalyticsProcessManager.java | 55 ++++-- .../ml/dataframe/process/AnalyticsResult.java | 20 ++- .../process/AnalyticsResultProcessor.java | 8 +- .../AnalyticsResultProcessorTests.java | 6 +- .../process/AnalyticsResultTests.java | 6 +- 21 files changed, 699 insertions(+), 90 deletions(-) create mode 100644 
x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index f39c08dc5d6a5..2f2a99010a99d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -73,6 +73,7 @@ import org.elasticsearch.xpack.core.ml.action.CloseJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction; import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction; import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction; @@ -125,6 +126,7 @@ import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage; import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage; @@ -304,6 +306,7 @@ public List> getClientActions() { PutDataFrameAnalyticsAction.INSTANCE, GetDataFrameAnalyticsAction.INSTANCE, GetDataFrameAnalyticsStatsAction.INSTANCE, + DeleteDataFrameAnalyticsAction.INSTANCE, StartDataFrameAnalyticsAction.INSTANCE, // security 
ClearRealmCacheAction.INSTANCE, @@ -388,9 +391,13 @@ public List getNamedWriteables() { StartDatafeedAction.DatafeedParams::new), new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_TASK_NAME, OpenJobAction.JobParams::new), + new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + StartDataFrameAnalyticsAction.TaskParams::new), // ML - Task states new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new), new NamedWriteableRegistry.Entry(PersistentTaskState.class, DatafeedState.NAME, DatafeedState::fromStream), + new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameAnalyticsTaskState.NAME, + DataFrameAnalyticsTaskState::new), new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MACHINE_LEARNING, MachineLearningFeatureSetUsage::new), // monitoring @@ -467,6 +474,8 @@ public List getNamedXContent() { // ML - Task states new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(DatafeedState.NAME), DatafeedState::fromXContent), new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(JobTaskState.NAME), JobTaskState::fromXContent), + new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(DataFrameAnalyticsTaskState.NAME), + DataFrameAnalyticsTaskState::fromXContent), // watcher new NamedXContentRegistry.Entry(MetaData.Custom.class, new ParseField(WatcherMetaData.TYPE), WatcherMetaData::fromXContent), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index 649a77648eafb..8064abebc296b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -29,7 +29,7 @@ public final class MlTasks { public static final String JOB_TASK_ID_PREFIX = "job-"; 
public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-"; - private static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; + public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; public static final PersistentTasksCustomMetaData.Assignment AWAITING_UPGRADE = new PersistentTasksCustomMetaData.Assignment(null, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java index baa7d2714ec8b..7d287557f5e93 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java @@ -19,19 +19,19 @@ public abstract class AbstractGetResourcesRequest extends ActionRequest { private String resourceId; private PageParams pageParams = PageParams.defaultParams(); - public void setResourceId(String resourceId) { + public final void setResourceId(String resourceId) { this.resourceId = resourceId; } - public String getResourceId() { + public final String getResourceId() { return resourceId; } - public void setPageParams(PageParams pageParams) { + public final void setPageParams(PageParams pageParams) { this.pageParams = pageParams; } - public PageParams getPageParams() { + public final PageParams getPageParams() { return pageParams; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java index 264b996b3e8f4..aeee3657604dd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java @@ -34,6 +34,10 @@ public static class Request extends AbstractGetResourcesRequest { public Request() {} + public Request(String id) { + setResourceId(id); + } + public Request(StreamInput in) throws IOException { readFrom(in); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index e7824bb08fb7a..3bf5ac10a9a0f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -5,11 +5,13 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.Action; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; @@ -19,6 +21,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.core.ml.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -26,6 +29,8 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -48,13 +53,17 @@ public Writeable.Reader getResponseReader() { return Response::new; } - public static class Request extends MasterNodeRequest { + public static class Request extends BaseTasksRequest { private String id; private PageParams pageParams = PageParams.defaultParams(); + // Used internally to store the expanded IDs + private List expandedIds = Collections.emptyList(); + public Request(String id) { this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID.getPreferredName()); + this.expandedIds = Collections.singletonList(id); } public Request() {} @@ -63,6 +72,15 @@ public Request(StreamInput in) throws IOException { super(in); id = in.readString(); pageParams = in.readOptionalWriteable(PageParams::new); + expandedIds = in.readStringList(); + } + + public void setExpandedIds(List expandedIds) { + this.expandedIds = Objects.requireNonNull(expandedIds); + } + + public List getExpandedIds() { + return expandedIds; } @Override @@ -70,6 +88,7 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(id); out.writeOptionalWriteable(pageParams); + out.writeStringCollection(expandedIds); } public void setId(String id) { @@ -88,6 +107,11 @@ public PageParams getPageParams() { return pageParams; } + @Override + public boolean match(Task task) { + return expandedIds.stream().anyMatch(expandedId -> StartDataFrameAnalyticsAction.TaskMatcher.match(task, expandedId)); + } + @Override public ActionRequestValidationException validate() { return null; @@ -118,21 +142,24 @@ public RequestBuilder(ElasticsearchClient client, GetDataFrameAnalyticsStatsActi } } - public static class Response extends ActionResponse implements ToXContentObject { 
+ public static class Response extends BaseTasksResponse implements ToXContentObject { public static class Stats implements ToXContentObject, Writeable { private final String id; private final DataFrameAnalyticsState state; @Nullable + private final Integer progressPercentage; + @Nullable private final DiscoveryNode node; @Nullable private final String assignmentExplanation; - public Stats(String id, DataFrameAnalyticsState state, @Nullable DiscoveryNode node, - @Nullable String assignmentExplanation) { + public Stats(String id, DataFrameAnalyticsState state, @Nullable Integer progressPercentage, + @Nullable DiscoveryNode node, @Nullable String assignmentExplanation) { this.id = Objects.requireNonNull(id); this.state = Objects.requireNonNull(state); + this.progressPercentage = progressPercentage; this.node = node; this.assignmentExplanation = assignmentExplanation; } @@ -140,6 +167,7 @@ public Stats(String id, DataFrameAnalyticsState state, @Nullable DiscoveryNode n public Stats(StreamInput in) throws IOException { id = in.readString(); state = DataFrameAnalyticsState.fromStream(in); + progressPercentage = in.readOptionalInt(); node = in.readOptionalWriteable(DiscoveryNode::new); assignmentExplanation = in.readOptionalString(); } @@ -152,14 +180,6 @@ public DataFrameAnalyticsState getState() { return state; } - public DiscoveryNode getNode() { - return node; - } - - public String getAssignmentExplanation() { - return assignmentExplanation; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { // TODO: Have callers wrap the content with an object as they choose rather than forcing it upon them @@ -173,6 +193,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOException { builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); builder.field("state", state.toString()); + if 
(progressPercentage != null) { + builder.field("progress_percent", progressPercentage); + } if (node != null) { builder.startObject("node"); builder.field("id", node.getId()); @@ -197,13 +220,14 @@ public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOExc public void writeTo(StreamOutput out) throws IOException { out.writeString(id); state.writeTo(out); + out.writeOptionalInt(progressPercentage); out.writeOptionalWriteable(node); out.writeOptionalString(assignmentExplanation); } @Override public int hashCode() { - return Objects.hash(id, state, node, assignmentExplanation); + return Objects.hash(id, state, progressPercentage, node, assignmentExplanation); } @Override @@ -224,9 +248,13 @@ public boolean equals(Object obj) { private QueryPage stats; - public Response() {} - public Response(QueryPage stats) { + this(Collections.emptyList(), Collections.emptyList(), stats); + } + + public Response(List taskFailures, List nodeFailures, + QueryPage stats) { + super(taskFailures, nodeFailures); this.stats = stats; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java index 2a3f7fbd008a2..2a7965261856f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.MasterNodeRequest; import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -19,6 +20,7 @@ import 
org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -162,4 +164,18 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } } + + public interface TaskMatcher { + + static boolean match(Task task, String expectedId) { + if (task instanceof TaskMatcher) { + if (MetaData.ALL.equals(expectedId)) { + return true; + } + String expectedDescription = MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + expectedId; + return expectedDescription.equals(task.getDescription()); + } + return false; + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 976774ed59371..8bc3202aa319b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -226,13 +226,17 @@ public static class Builder { private Map query = Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()); private Map headers = Collections.emptyMap(); + public Builder() {} + + public Builder(String id) { + setId(id); + } + public String getId() { return id; } - public Builder() {} - - public Builder(DataFrameAnalyticsConfig config) { + public Builder(DataFrameAnalyticsConfig config) { this.id = config.id; this.source = config.source; this.dest = config.dest; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java index 5d9b7ba756190..994faaaee6cc2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.dataframe; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; @@ -20,6 +21,8 @@ public class DataFrameAnalyticsTaskState implements PersistentTaskState { + public static final String NAME = MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + private static ParseField STATE = new ParseField("state"); private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); @@ -27,7 +30,7 @@ public class DataFrameAnalyticsTaskState implements PersistentTaskState { private final long allocationId; private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, + new ConstructingObjectParser<>(NAME, true, a -> new DataFrameAnalyticsTaskState((DataFrameAnalyticsState) a[0], (long) a[1])); static { @@ -53,6 +56,11 @@ public DataFrameAnalyticsTaskState(DataFrameAnalyticsState state, long allocatio this.allocationId = allocationId; } + public DataFrameAnalyticsTaskState(StreamInput in) throws IOException { + this.state = DataFrameAnalyticsState.fromStream(in); + this.allocationId = in.readLong(); + } + public DataFrameAnalyticsState getState() { return state; } @@ -63,7 +71,7 @@ public boolean isStatusStale(PersistentTasksCustomMetaData.PersistentTask tas @Override public String getWriteableName() { - return MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + return NAME; 
} @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index 65c51faa9157e..ed9599b05152a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -22,8 +22,9 @@ protected Response createTestInstance() { int listSize = randomInt(10); List analytics = new ArrayList<>(listSize); for (int j = 0; j < listSize; j++) { + Integer progressPercentage = randomBoolean() ? null : randomIntBetween(0, 100); Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), - randomFrom(DataFrameAnalyticsState.values()), null, randomAlphaOfLength(20)); + randomFrom(DataFrameAnalyticsState.values()), progressPercentage, null, randomAlphaOfLength(20)); analytics.add(stats); } return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java new file mode 100644 index 0000000000000..c82f5760c637b --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import static org.hamcrest.Matchers.equalTo; + +/** + * Base class of ML integration tests that use a native data_frame_analytics process + */ +abstract class MlNativeDataFrameAnalyticsIntegTestCase extends MlNativeIntegTestCase { + + private List analytics = new ArrayList<>(); + + @Override + protected void cleanUpResources() { + cleanUpAnalytics(); + } + + private void cleanUpAnalytics() { + for (DataFrameAnalyticsConfig config : analytics) { + try { + deleteAnalytics(config.getId()); + } catch (Exception e) { + // ignore + } + } + } + + protected void registerAnalytics(DataFrameAnalyticsConfig config) { + if (analytics.add(config) == false) { + throw new IllegalArgumentException("analytics config [" + config.getId() + "] is already registered"); + } + } + + protected PutDataFrameAnalyticsAction.Response 
putAnalytics(DataFrameAnalyticsConfig config) { + PutDataFrameAnalyticsAction.Request request = new PutDataFrameAnalyticsAction.Request(config); + return client().execute(PutDataFrameAnalyticsAction.INSTANCE, request).actionGet(); + } + + protected AcknowledgedResponse deleteAnalytics(String id) { + DeleteDataFrameAnalyticsAction.Request request = new DeleteDataFrameAnalyticsAction.Request(id); + return client().execute(DeleteDataFrameAnalyticsAction.INSTANCE, request).actionGet(); + } + + protected AcknowledgedResponse startAnalytics(String id) { + StartDataFrameAnalyticsAction.Request request = new StartDataFrameAnalyticsAction.Request(id); + return client().execute(StartDataFrameAnalyticsAction.INSTANCE, request).actionGet(); + } + + protected void waitUntilAnalyticsIsStopped(String id) throws Exception { + waitUntilAnalyticsIsStopped(id, TimeValue.timeValueSeconds(30)); + } + + protected void waitUntilAnalyticsIsStopped(String id, TimeValue waitTime) throws Exception { + assertBusy(() -> assertThat(getAnalyticsStats(id).get(0).getState(), equalTo(DataFrameAnalyticsState.STOPPED)), + waitTime.getMillis(), TimeUnit.MILLISECONDS); + } + + protected List getAnalytics(String id) { + GetDataFrameAnalyticsAction.Request request = new GetDataFrameAnalyticsAction.Request(id); + return client().execute(GetDataFrameAnalyticsAction.INSTANCE, request).actionGet().getResources().results(); + } + + protected List getAnalyticsStats(String id) { + GetDataFrameAnalyticsStatsAction.Request request = new GetDataFrameAnalyticsStatsAction.Request(id); + GetDataFrameAnalyticsStatsAction.Response response = client().execute(GetDataFrameAnalyticsStatsAction.INSTANCE, request) + .actionGet(); + return response.getResponse().results(); + } + + protected List generateData(long timestamp, TimeValue bucketSpan, int bucketCount, + Function timeToCountFunction) throws IOException { + List data = new ArrayList<>(); + long now = timestamp; + for (int bucketIndex = 0; bucketIndex < bucketCount; 
bucketIndex++) { + for (int count = 0; count < timeToCountFunction.apply(bucketIndex); count++) { + Map record = new HashMap<>(); + record.put("time", now); + data.add(createJsonRecord(record)); + } + now += bucketSpan.getMillis(); + } + return data; + } + + protected static String createJsonRecord(Map keyValueMap) throws IOException { + return Strings.toString(JsonXContent.contentBuilder().map(keyValueMap)) + "\n"; + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index f844d813cb5fa..3e371150105e4 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -29,8 +29,10 @@ import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.security.SecurityField; import org.elasticsearch.xpack.core.security.authc.TokenMetaData; @@ -112,10 +114,14 @@ protected void ensureClusterStateConsistency() throws IOException { entries.add(new NamedWriteableRegistry.Entry(MetaData.Custom.class, "ml", MlMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATAFEED_TASK_NAME, StartDatafeedAction.DatafeedParams::new)); + entries.add(new 
NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + StartDataFrameAnalyticsAction.TaskParams::new)); entries.add(new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_TASK_NAME, OpenJobAction.JobParams::new)); entries.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new)); entries.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, DatafeedState.NAME, DatafeedState::fromStream)); + entries.add(new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameAnalyticsTaskState.NAME, + DataFrameAnalyticsTaskState::new)); entries.add(new NamedWriteableRegistry.Entry(ClusterState.Custom.class, TokenMetaData.TYPE, TokenMetaData::new)); final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(entries); ClusterState masterClusterState = client().admin().cluster().prepareState().all().get().getState(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java new file mode 100644 index 0000000000000..45395403a3a1a --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.junit.After; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase { + + @After + public void cleanup() { + cleanUp(); + } + + public void testOutlierDetectionWithFewDocuments() throws Exception { + String sourceIndex = "test-outlier-detection-with-few-docs"; + + client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical_1", "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (int i = 0; i < 5; i++) { + IndexRequest indexRequest = new IndexRequest(sourceIndex); + + // We insert one odd value out of 5 for one feature + String docId = i == 0 ? 
"outlier" : "normal" + i; + indexRequest.id(docId); + indexRequest.source("numeric_1", i == 0 ? 100.0 : 1.0, "numeric_2", 1.0, "categorical_1", "foo_" + i); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_outlier_detection_with_few_docs"; + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); + registerAnalytics(config); + putAnalytics(config); + + assertState(id, DataFrameAnalyticsState.STOPPED); + + startAnalytics(id); + waitUntilAnalyticsIsStopped(id); + + SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); + double scoreOfOutlier = 0.0; + double scoreOfNonOutlier = -1.0; + for (SearchHit hit : sourceData.getHits()) { + GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest()).setId(hit.getId()).get(); + assertThat(destDocGetResponse.isExists(), is(true)); + Map sourceDoc = hit.getSourceAsMap(); + Map destDoc = destDocGetResponse.getSource(); + for (String field : sourceDoc.keySet()) { + assertThat(destDoc.containsKey(field), is(true)); + assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); + } + assertThat(destDoc.containsKey("outlier_score"), is(true)); + double outlierScore = (double) destDoc.get("outlier_score"); + assertThat(outlierScore, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(100.0))); + if (hit.getId().equals("outlier")) { + scoreOfOutlier = outlierScore; + } else { + if (scoreOfNonOutlier < 0) { + scoreOfNonOutlier = outlierScore; + } else { + assertThat(outlierScore, equalTo(scoreOfNonOutlier)); + } + } + } + assertThat(scoreOfOutlier, is(greaterThan(scoreOfNonOutlier))); + } + + public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { + String sourceIndex = "test-outlier-detection-with-enough-docs-to-scroll"; + + 
client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical_1", "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + int docCount = randomIntBetween(1024, 2048); + for (int i = 0; i < docCount; i++) { + IndexRequest indexRequest = new IndexRequest(sourceIndex); + indexRequest.source("numeric_1", randomDouble(), "numeric_2", randomFloat(), "categorical_1", randomAlphaOfLength(10)); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_outlier_detection_with_enough_docs_to_scroll"; + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); + registerAnalytics(config); + putAnalytics(config); + + assertState(id, DataFrameAnalyticsState.STOPPED); + + startAnalytics(id); + waitUntilAnalyticsIsStopped(id); + + // Check we've got all docs + SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); + + // Check they all have an outlier_score + searchResponse = client().prepareSearch(config.getDest()) + .setTrackTotalHits(true) + .setQuery(QueryBuilders.existsQuery("outlier_score")).get(); + assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); + } + + private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex) { + DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); + configBuilder.setSource(sourceIndex); + configBuilder.setDest(sourceIndex + "-results"); + Map analysisConfig = new HashMap<>(); + 
analysisConfig.put("outlier_detection", Collections.emptyMap()); + configBuilder.setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig(analysisConfig))); + return configBuilder.build(); + } + + private void assertState(String id, DataFrameAnalyticsState state) { + List stats = getAnalyticsStats(id); + assertThat(stats.size(), equalTo(1)); + assertThat(stats.get(0).getId(), equalTo(id)); + assertThat(stats.get(0).getState(), equalTo(state)); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index bf9893e7f74d4..2cafc7a39076b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; @@ -474,7 +475,8 @@ public Collection createComponents(Client client, ClusterService cluster AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, environment, threadPool, analyticsProcessFactory); DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client); - DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(clusterService, client, + assert client instanceof NodeClient; + DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(clusterService, (NodeClient) client, dataFrameAnalyticsConfigProvider, analyticsProcessManager); this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager); diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index 1a30200668d93..5f096df93bf8d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -5,67 +5,143 @@ */ package org.elasticsearch.xpack.ml.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskRequest; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.action.support.tasks.TransportTasksAction; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ClusterState; -import org.elasticsearch.cluster.block.ClusterBlockException; -import org.elasticsearch.cluster.block.ClusterBlockLevel; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.index.reindex.BulkByScrollTask; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskResult; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import 
org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response.Stats; import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; public class TransportGetDataFrameAnalyticsStatsAction - extends TransportMasterNodeAction { + extends TransportTasksAction> { + + private static final Logger LOGGER = LogManager.getLogger(TransportGetDataFrameAnalyticsStatsAction.class); private final Client client; + private final AnalyticsProcessManager analyticsProcessManager; @Inject public TransportGetDataFrameAnalyticsStatsAction(TransportService transportService, ClusterService clusterService, Client client, - ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver) { - super(GetDataFrameAnalyticsStatsAction.NAME, transportService, clusterService, threadPool, actionFilters, - indexNameExpressionResolver, GetDataFrameAnalyticsStatsAction.Request::new); + ActionFilters actionFilters, AnalyticsProcessManager analyticsProcessManager) { + super(GetDataFrameAnalyticsStatsAction.NAME, clusterService, transportService, actionFilters, + 
GetDataFrameAnalyticsStatsAction.Request::new, GetDataFrameAnalyticsStatsAction.Response::new, + in -> new QueryPage<>(in, GetDataFrameAnalyticsStatsAction.Response.Stats::new), ThreadPool.Names.MANAGEMENT); this.client = client; + this.analyticsProcessManager = analyticsProcessManager; } @Override - protected String executor() { - return ThreadPool.Names.SAME; + protected GetDataFrameAnalyticsStatsAction.Response newResponse(GetDataFrameAnalyticsStatsAction.Request request, + List> tasks, + List taskFailures, + List nodeFailures) { + List stats = new ArrayList<>(); + for (QueryPage task : tasks) { + stats.addAll(task.results()); + } + Collections.sort(stats, Comparator.comparing(Stats::getId)); + return new GetDataFrameAnalyticsStatsAction.Response(taskFailures, nodeFailures, new QueryPage<>(stats, stats.size(), + GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); } @Override - protected GetDataFrameAnalyticsStatsAction.Response newResponse() { - return new GetDataFrameAnalyticsStatsAction.Response(); + protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, DataFrameAnalyticsTask task, + ActionListener> listener) { + LOGGER.debug("Get stats for running task [{}]", task.getParams().getId()); + + ActionListener progressListener = ActionListener.wrap( + progress -> { + Stats stats = buildStats(task.getParams().getId(), progress); + listener.onResponse(new QueryPage<>(Collections.singletonList(stats), 1, + GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); + }, listener::onFailure + ); + + ClusterState clusterState = clusterService.state(); + PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + DataFrameAnalyticsState analyticsState = MlTasks.getDataFrameAnalyticsState(task.getParams().getId(), tasks); + + // For a running task we report the progress associated with its current state + if (analyticsState == DataFrameAnalyticsState.REINDEXING) { + getReindexTaskProgress(task, 
progressListener); + } else { + progressListener.onResponse(analyticsProcessManager.getProgressPercent(task.getAllocationId())); + } + } + + private void getReindexTaskProgress(DataFrameAnalyticsTask task, ActionListener listener) { + TaskId reindexTaskId = new TaskId(clusterService.localNode().getId(), task.getReindexingTaskId()); + GetTaskRequest getTaskRequest = new GetTaskRequest(); + getTaskRequest.setTaskId(reindexTaskId); + client.admin().cluster().getTask(getTaskRequest, ActionListener.wrap( + taskResponse -> { + TaskResult taskResult = taskResponse.getTask(); + BulkByScrollTask.Status taskStatus = (BulkByScrollTask.Status) taskResult.getTask().getStatus(); + int progress = taskStatus.getTotal() == 0 ? 100 : (int) (taskStatus.getCreated() * 100.0 / taskStatus.getTotal()); + listener.onResponse(progress); + }, + error -> { + if (error instanceof ResourceNotFoundException) { + // The task has either not started yet or has finished, thus it is better to respond null and not show progress at all + listener.onResponse(null); + } else { + listener.onFailure(error); + } + } + )); } @Override - protected void masterOperation(GetDataFrameAnalyticsStatsAction.Request request, ClusterState state, - ActionListener listener) throws Exception { - PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + protected void doExecute(Task task, GetDataFrameAnalyticsStatsAction.Request request, + ActionListener listener) { + LOGGER.debug("Get stats for data frame analytics [{}]", request.getId()); ActionListener getResponseListener = ActionListener.wrap( response -> { - List stats = new ArrayList(response.getResources().results().size()); - response.getResources().results().forEach(c -> stats.add(buildStats(c.getId(), tasks, state))); - listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>(stats, stats.size(), - GetDataFrameAnalyticsAction.Response.RESULTS_FIELD))); + List expandedIds = 
response.getResources().results().stream().map(DataFrameAnalyticsConfig::getId) + .collect(Collectors.toList()); + request.setExpandedIds(expandedIds); + ActionListener runningTasksStatsListener = ActionListener.wrap( + runningTasksStatsResponse -> gatherStatsForStoppedTasks(request.getExpandedIds(), runningTasksStatsResponse, listener), + listener::onFailure + ); + super.doExecute(task, request, runningTasksStatsListener); }, listener::onFailure ); @@ -76,8 +152,29 @@ protected void masterOperation(GetDataFrameAnalyticsStatsAction.Request request, executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, getRequest, getResponseListener); } - private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, PersistentTasksCustomMetaData tasks, - ClusterState clusterState) { + void gatherStatsForStoppedTasks(List expandedIds, GetDataFrameAnalyticsStatsAction.Response runningTasksResponse, + ActionListener listener) { + List stoppedTasksIds = determineStoppedTasksIds(expandedIds, runningTasksResponse.getResponse().results()); + List stoppedTasksStats = stoppedTasksIds.stream().map(this::buildStatsForStoppedTask).collect(Collectors.toList()); + List allTasksStats = new ArrayList<>(runningTasksResponse.getResponse().results()); + allTasksStats.addAll(stoppedTasksStats); + Collections.sort(allTasksStats, Comparator.comparing(Stats::getId)); + listener.onResponse(new GetDataFrameAnalyticsStatsAction.Response(new QueryPage<>( + allTasksStats, allTasksStats.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD))); + } + + static List determineStoppedTasksIds(List expandedIds, List runningTasksStats) { + Set startedTasksIds = runningTasksStats.stream().map(Stats::getId).collect(Collectors.toSet()); + return expandedIds.stream().filter(id -> startedTasksIds.contains(id) == false).collect(Collectors.toList()); + } + + private GetDataFrameAnalyticsStatsAction.Response.Stats buildStatsForStoppedTask(String 
concreteAnalyticsId) { + return buildStats(concreteAnalyticsId, null); + } + + private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concreteAnalyticsId, @Nullable Integer progressPercent) { + ClusterState clusterState = clusterService.state(); + PersistentTasksCustomMetaData tasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); PersistentTasksCustomMetaData.PersistentTask analyticsTask = MlTasks.getDataFrameAnalyticsTask(concreteAnalyticsId, tasks); DataFrameAnalyticsState analyticsState = MlTasks.getDataFrameAnalyticsState(concreteAnalyticsId, tasks); DiscoveryNode node = null; @@ -87,11 +184,6 @@ private GetDataFrameAnalyticsStatsAction.Response.Stats buildStats(String concre assignmentExplanation = analyticsTask.getAssignment().getExplanation(); } return new GetDataFrameAnalyticsStatsAction.Response.Stats( - concreteAnalyticsId, analyticsState, node, assignmentExplanation); - } - - @Override - protected ClusterBlockException checkBlock(GetDataFrameAnalyticsStatsAction.Request request, ClusterState state) { - return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_READ); + concreteAnalyticsId, analyticsState, progressPercent, node, assignmentExplanation); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 5b18eeda7d277..cb7e84651ac73 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.xpack.ml.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; 
import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.action.ActionListener; @@ -18,10 +21,13 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTaskParams; import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.persistent.PersistentTasksExecutor; @@ -53,6 +59,8 @@ public class TransportStartDataFrameAnalyticsAction extends TransportMasterNodeAction { + private static final Logger LOGGER = LogManager.getLogger(TransportStartDataFrameAnalyticsAction.class); + private final XPackLicenseState licenseState; private final Client client; private final PersistentTasksService persistentTasksService; @@ -107,7 +115,7 @@ protected void masterOperation(StartDataFrameAnalyticsAction.Request request, Cl new ActionListener>() { @Override public void onResponse(PersistentTasksCustomMetaData.PersistentTask task) { - listener.onResponse(new AcknowledgedResponse(true)); + waitForAnalyticsStarted(task, listener); } @Override @@ -138,19 +146,111 @@ public void onFailure(Exception e) { configProvider.get(request.getId(), configListener); } - public static class DataFrameAnalyticsTask extends AllocatedPersistentTask { + private void waitForAnalyticsStarted(PersistentTasksCustomMetaData.PersistentTask task, + ActionListener listener) { + AnalyticsPredicate predicate = new AnalyticsPredicate(); + // TODO Add timeout parameter to the start analytics request and use it here instead of hardcoded value 
+ persistentTasksService.waitForPersistentTaskCondition(task.getId(), predicate, TimeValue.timeValueSeconds(10), + new PersistentTasksService.WaitForPersistentTaskListener() { + + @Override + public void onResponse(PersistentTasksCustomMetaData.PersistentTask persistentTask) { + if (predicate.exception != null) { + // We want to return to the caller without leaving an unassigned persistent task, to match + // what would have happened if the error had been detected in the "fast fail" validation + cancelAnalyticsStart(task, predicate.exception, listener); + } else { + listener.onResponse(new AcknowledgedResponse(true)); + } + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + public void onTimeout(TimeValue timeout) { + listener.onFailure(new ElasticsearchException("Starting data frame analytics [" + task.getParams().getId() + + "] timed out after [" + timeout + "]")); + } + }); + } + + /** + * Important: the methods of this class must NOT throw exceptions. If they did then the callers + * of endpoints waiting for a condition tested by this predicate would never get a response. 
+ */ + private class AnalyticsPredicate implements Predicate> { + + private volatile Exception exception; + + @Override + public boolean test(PersistentTasksCustomMetaData.PersistentTask persistentTask) { + if (persistentTask == null) { + return false; + } + + PersistentTasksCustomMetaData.Assignment assignment = persistentTask.getAssignment(); + if (assignment != null && assignment.equals(PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT) == false && + assignment.isAssigned() == false) { + // Assignment has failed despite passing our "fast fail" validation + exception = new ElasticsearchStatusException("Could not start data frame analytics task, allocation explanation [" + + assignment.getExplanation() + "]", RestStatus.TOO_MANY_REQUESTS); + return true; + } + DataFrameAnalyticsTaskState taskState = (DataFrameAnalyticsTaskState) persistentTask.getState(); + DataFrameAnalyticsState analyticsState = taskState == null ? DataFrameAnalyticsState.STOPPED : taskState.getState(); + return analyticsState == DataFrameAnalyticsState.STARTED; + } + } + + private void cancelAnalyticsStart( + PersistentTasksCustomMetaData.PersistentTask persistentTask, Exception exception, + ActionListener listener) { + persistentTasksService.sendRemoveRequest(persistentTask.getId(), + new ActionListener>() { + @Override + public void onResponse(PersistentTasksCustomMetaData.PersistentTask task) { + // We succeeded in cancelling the persistent task, but the + // problem that caused us to cancel it is the overall result + listener.onFailure(exception); + } + + @Override + public void onFailure(Exception e) { + LOGGER.error("[" + persistentTask.getParams().getId() + "] Failed to cancel persistent task that could " + + "not be assigned due to [" + exception.getMessage() + "]", e); + listener.onFailure(exception); + } + } + ); + } + + public static class DataFrameAnalyticsTask extends AllocatedPersistentTask implements StartDataFrameAnalyticsAction.TaskMatcher { private final 
StartDataFrameAnalyticsAction.TaskParams taskParams; + @Nullable + private volatile Long reindexingTaskId; public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map headers, StartDataFrameAnalyticsAction.TaskParams taskParams) { - super(id, type, action, "data_frame_analytics-" + taskParams.getId(), parentTask, headers); + super(id, type, action, MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + taskParams.getId(), parentTask, headers); this.taskParams = Objects.requireNonNull(taskParams); } public StartDataFrameAnalyticsAction.TaskParams getParams() { return taskParams; } + + public void setReindexingTaskId(long reindexingTaskId) { + this.reindexingTaskId = reindexingTaskId; + } + + @Nullable + public Long getReindexingTaskId() { + return reindexingTaskId; + } } public static class TaskExecutor extends PersistentTasksExecutor { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 0395f3115bbc3..2ad1de10b76db 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -12,12 +12,14 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; -import org.elasticsearch.client.Client; +import org.elasticsearch.action.support.ContextPreservingActionListener; +import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.MappingMetaData; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.collect.ImmutableOpenMap; import 
org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.IndexSortConfig; import org.elasticsearch.index.reindex.BulkByScrollResponse; @@ -25,6 +27,7 @@ import org.elasticsearch.index.reindex.ReindexRequest; import org.elasticsearch.script.Script; import org.elasticsearch.search.sort.SortOrder; +import org.elasticsearch.tasks.Task; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; @@ -39,6 +42,9 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; public class DataFrameAnalyticsManager { @@ -55,11 +61,14 @@ public class DataFrameAnalyticsManager { ); private final ClusterService clusterService; - private final Client client; + /** + * We need a {@link NodeClient} to be get the reindexing task and be able to report progress + */ + private final NodeClient client; private final DataFrameAnalyticsConfigProvider configProvider; private final AnalyticsProcessManager processManager; - public DataFrameAnalyticsManager(ClusterService clusterService, Client client, DataFrameAnalyticsConfigProvider configProvider, + public DataFrameAnalyticsManager(ClusterService clusterService, NodeClient client, DataFrameAnalyticsConfigProvider configProvider, AnalyticsProcessManager processManager) { this.clusterService = Objects.requireNonNull(clusterService); this.client = Objects.requireNonNull(client); @@ -117,12 +126,14 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF reindexRequest.setSourceQuery(config.getParsedQuery()); reindexRequest.setDestIndex(config.getDest()); reindexRequest.setScript(new Script("ctx._source." 
+ DataFrameAnalyticsFields.ID + " = ctx._id")); - ClientHelper.executeWithHeadersAsync(config.getHeaders(), - ClientHelper.ML_ORIGIN, - client, - ReindexAction.INSTANCE, - reindexRequest, - reindexCompletedListener); + + final ThreadContext threadContext = client.threadPool().getThreadContext(); + final Supplier supplier = threadContext.newRestorableContext(false); + try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(ML_ORIGIN)) { + Task reindexTask = client.executeLocally(ReindexAction.INSTANCE, reindexRequest, + new ContextPreservingActionListener<>(supplier, reindexCompletedListener)); + task.setReindexingTaskId(reindexTask.getId()); + } }, reindexCompletedListener::onFailure ); @@ -137,7 +148,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi DataFrameAnalyticsTaskState analyzingState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.ANALYZING, task.getAllocationId()); task.updatePersistentTaskState(analyzingState, ActionListener.wrap( - updatedTask -> processManager.runJob(config, dataExtractorFactory, + updatedTask -> processManager.runJob(task.getAllocationId(), config, dataExtractorFactory, error -> { if (error != null) { task.markAsFailed(error); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 5868cefe3a30e..2c2c110b76c12 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -8,11 +8,15 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.admin.indices.refresh.RefreshAction; +import 
org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.client.Client; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.env.Environment; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; @@ -25,7 +29,10 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; public class AnalyticsProcessManager { @@ -36,6 +43,7 @@ public class AnalyticsProcessManager { private final Environment environment; private final ThreadPool threadPool; private final AnalyticsProcessFactory processFactory; + private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); public AnalyticsProcessManager(Client client, Environment environment, ThreadPool threadPool, AnalyticsProcessFactory analyticsProcessFactory) { @@ -45,46 +53,51 @@ public AnalyticsProcessManager(Client client, Environment environment, ThreadPoo this.processFactory = Objects.requireNonNull(analyticsProcessFactory); } - public void runJob(DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory, + public void runJob(long taskAllocationId, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory, Consumer finishHandler) { threadPool.generic().execute(() -> { DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); AnalyticsProcess process = createProcess(config.getId(), createProcessConfig(config, 
dataExtractor)); + processContextByAllocation.putIfAbsent(taskAllocationId, new ProcessContext()); ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); - AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(dataFrameRowsJoiner); + AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(processContextByAllocation.get(taskAllocationId), + dataFrameRowsJoiner); executorService.execute(() -> resultProcessor.process(process)); - executorService.execute(() -> processData(config.getId(), dataExtractor, process, resultProcessor, finishHandler)); + executorService.execute( + () -> processData(taskAllocationId, config, dataExtractor, process, resultProcessor, finishHandler)); }); } - private void processData(String jobId, DataFrameDataExtractor dataExtractor, AnalyticsProcess process, - AnalyticsResultProcessor resultProcessor, Consumer finishHandler) { + private void processData(long taskAllocationId, DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor, + AnalyticsProcess process, AnalyticsResultProcessor resultProcessor, Consumer finishHandler) { try { writeHeaderRecord(dataExtractor, process); writeDataRows(dataExtractor, process); process.writeEndOfDataMessage(); process.flushStream(); - LOGGER.info("[{}] Waiting for result processor to complete", jobId); + LOGGER.info("[{}] Waiting for result processor to complete", config.getId()); resultProcessor.awaitForCompletion(); - LOGGER.info("[{}] Result processor has completed", jobId); + refreshDest(config); + LOGGER.info("[{}] Result processor has completed", config.getId()); } catch (IOException e) { - LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", jobId), e); + LOGGER.error(new ParameterizedMessage("[{}] Error writing data to the process", config.getId()), e); // 
TODO Handle this failure by setting the task state to FAILED } finally { - LOGGER.info("[{}] Closing process", jobId); + LOGGER.info("[{}] Closing process", config.getId()); try { process.close(); - LOGGER.info("[{}] Closed process", jobId); + LOGGER.info("[{}] Closed process", config.getId()); // This results in marking the persistent task as complete finishHandler.accept(null); } catch (IOException e) { - LOGGER.error("[{}] Error closing data frame analyzer process", jobId); + LOGGER.error("[{}] Error closing data frame analyzer process", config.getId()); finishHandler.accept(e); } + processContextByAllocation.remove(taskAllocationId); } } @@ -145,4 +158,24 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig conf new ByteSizeValue(1, ByteSizeUnit.GB), 1, dataFrameAnalyses.get(0)); return processConfig; } + + @Nullable + public Integer getProgressPercent(long allocationId) { + ProcessContext processContext = processContextByAllocation.get(allocationId); + return processContext == null ? 
null : processContext.progressPercent.get(); + } + + private void refreshDest(DataFrameAnalyticsConfig config) { + ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, + () -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest())).actionGet()); + } + + static class ProcessContext { + + private final AtomicInteger progressPercent = new AtomicInteger(0); + + void setProgressPercent(int progressPercent) { + this.progressPercent.set(progressPercent); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java index 4d15bf89b29e9..ced64ab04a280 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResult.java @@ -18,29 +18,41 @@ public class AnalyticsResult implements ToXContentObject { public static final ParseField TYPE = new ParseField("analytics_result"); + public static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), - a -> new AnalyticsResult((RowResults) a[0])); + a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1])); static { PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), PROGRESS_PERCENT); } private final RowResults rowResults; + private final Integer progressPercent; - public AnalyticsResult(RowResults rowResults) { + public AnalyticsResult(RowResults rowResults, Integer progressPercent) { this.rowResults = rowResults; + this.progressPercent = progressPercent; } public RowResults getRowResults() { return rowResults; } + public Integer 
getProgressPercent() { + return progressPercent; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); if (rowResults != null) { builder.field(RowResults.TYPE.getPreferredName(), rowResults); } + if (progressPercent != null) { + builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent); + } builder.endObject(); return builder; } @@ -55,11 +67,11 @@ public boolean equals(Object other) { } AnalyticsResult that = (AnalyticsResult) other; - return Objects.equals(rowResults, that.rowResults); + return Objects.equals(rowResults, that.rowResults) && Objects.equals(progressPercent, that.progressPercent); } @Override public int hashCode() { - return Objects.hash(rowResults); + return Objects.hash(rowResults, progressPercent); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 1b70d68598df6..b2721daf1e515 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -18,10 +18,12 @@ public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); + private final AnalyticsProcessManager.ProcessContext processContext; private final DataFrameRowsJoiner dataFrameRowsJoiner; private final CountDownLatch completionLatch = new CountDownLatch(1); - public AnalyticsResultProcessor(DataFrameRowsJoiner dataFrameRowsJoiner) { + public AnalyticsResultProcessor(AnalyticsProcessManager.ProcessContext processContext, DataFrameRowsJoiner dataFrameRowsJoiner) { + this.processContext = Objects.requireNonNull(processContext); this.dataFrameRowsJoiner = 
Objects.requireNonNull(dataFrameRowsJoiner); } @@ -57,5 +59,9 @@ private void processResult(AnalyticsResult result) { if (rowResults != null) { dataFrameRowsJoiner.processRowResults(rowResults); } + Integer progressPercent = result.getProgressPercent(); + if (progressPercent != null) { + processContext.setProgressPercent(progressPercent); + } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index cded955767344..716b96a615846 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -41,7 +41,7 @@ public void testProcess_GivenNoResults() { } public void testProcess_GivenEmptyResults() { - givenProcessResults(Arrays.asList(new AnalyticsResult(null), new AnalyticsResult(null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, null), new AnalyticsResult(null, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -53,7 +53,7 @@ public void testProcess_GivenEmptyResults() { public void testProcess_GivenRowResults() { RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1), new AnalyticsResult(rowResults2))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null), new AnalyticsResult(rowResults2, null))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -69,6 +69,6 @@ private void givenProcessResults(List results) { } private AnalyticsResultProcessor createResultProcessor() { - return new AnalyticsResultProcessor(dataFrameRowsJoiner); + return 
new AnalyticsResultProcessor(new AnalyticsProcessManager.ProcessContext(), dataFrameRowsJoiner); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java index d0d3b4ee5f99c..22c03d47682e8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultTests.java @@ -17,10 +17,14 @@ public class AnalyticsResultTests extends AbstractXContentTestCase Date: Tue, 19 Feb 2019 09:34:30 -0600 Subject: [PATCH 21/67] [FEATURE][ML] Add ability to include/exclude fields to/from analyses (#38797) * [ML-DATA-FRAME] Add ability to include/exclude fields to/from analysis * s/fields/analysis_fields/g --- .../dataframe/DataFrameAnalyticsConfig.java | 37 ++++++++++-- .../xpack/core/ml/job/messages/Messages.java | 2 + .../persistence/ElasticsearchMappings.java | 3 + .../ml/job/results/ReservedFieldNames.java | 1 + .../DataFrameAnalyticsConfigTests.java | 6 ++ .../DataFrameDataExtractorFactory.java | 49 ++++++++++++++-- .../DataFrameDataExtractorFactoryTests.java | 56 ++++++++++++++++--- .../test/ml/data_frame_analytics_crud.yml | 6 +- 8 files changed, 142 insertions(+), 18 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 8bc3202aa319b..9948b899e2907 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.xcontent.XContentParseException; import 
org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; @@ -35,6 +36,8 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING; + public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { private static final Logger logger = LogManager.getLogger(DataFrameAnalyticsConfig.class); @@ -64,6 +67,7 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { public static final ParseField ANALYSES = new ParseField("analyses"); public static final ParseField CONFIG_TYPE = new ParseField("config_type"); public static final ParseField QUERY = new ParseField("query"); + public static final ParseField ANALYSES_FIELDS = new ParseField("analyses_fields"); public static final ParseField HEADERS = new ParseField("headers"); public static final ObjectParser STRICT_PARSER = createParser(false); @@ -78,6 +82,10 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel parser.declareString(Builder::setDest, DEST); parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES); parser.declareObject((builder, query) -> builder.setQuery(query, ignoreUnknownFields), (p, c) -> p.mapOrdered(), QUERY); + parser.declareField(Builder::setAnalysesFields, + (p, c) -> FetchSourceContext.fromXContent(p), + ANALYSES_FIELDS, + OBJECT_ARRAY_BOOLEAN_OR_STRING); if (ignoreUnknownFields) { // Headers are not parsed by the strict (config) parser, so headers supplied in the _body_ of a REST request will be rejected. // (For config, headers are explicitly transferred from the auth headers by code in the put data frame actions.) 
@@ -92,10 +100,11 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final List analyses; private final Map query; private final CachedSupplier querySupplier; + private final FetchSourceContext analysesFields; private final Map headers; public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses, - Map query, Map headers) { + Map query, Map headers, FetchSourceContext analysesFields) { this.id = ExceptionsHelper.requireNonNull(id, ID); this.source = ExceptionsHelper.requireNonNull(source, SOURCE); this.dest = ExceptionsHelper.requireNonNull(dest, DEST); @@ -109,6 +118,7 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); + this.analysesFields = analysesFields; this.headers = Collections.unmodifiableMap(headers); } @@ -119,6 +129,7 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException { analyses = in.readList(DataFrameAnalysisConfig::new); this.query = in.readMap(); this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); + this.analysesFields = in.readOptionalWriteable(FetchSourceContext::new); this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); } @@ -148,6 +159,10 @@ public QueryBuilder getParsedQuery() { return querySupplier.get(); } + public FetchSourceContext getAnalysesFields() { + return analysesFields; + } + /** * Calls the lazy parser and returns any gathered deprecations * @return The deprecations from parsing the query @@ -177,6 +192,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(CONFIG_TYPE.getPreferredName(), TYPE); } builder.field(QUERY.getPreferredName(), query); + if (analysesFields != null) { + builder.field(ANALYSES_FIELDS.getPreferredName(), analysesFields); + } if (headers.isEmpty() == false && 
params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { builder.field(HEADERS.getPreferredName(), headers); } @@ -191,6 +209,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(dest); out.writeList(analyses); out.writeMap(query); + out.writeOptionalWriteable(analysesFields); out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); } @@ -204,13 +223,14 @@ public boolean equals(Object o) { && Objects.equals(source, other.source) && Objects.equals(dest, other.dest) && Objects.equals(analyses, other.analyses) + && Objects.equals(query, other.query) && Objects.equals(headers, other.headers) - && Objects.equals(query, other.query); + && Objects.equals(analysesFields, other.analysesFields); } @Override public int hashCode() { - return Objects.hash(id, source, dest, analyses, query, headers); + return Objects.hash(id, source, dest, analyses, query, headers, analysesFields); } public static String documentId(String id) { @@ -224,6 +244,7 @@ public static class Builder { private String dest; private List analyses; private Map query = Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()); + private FetchSourceContext analysesFields; private Map headers = Collections.emptyMap(); public Builder() {} @@ -243,6 +264,9 @@ public Builder(DataFrameAnalyticsConfig config) { this.analyses = new ArrayList<>(config.analyses); this.query = new LinkedHashMap<>(config.query); this.headers = new HashMap<>(config.headers); + if (config.analysesFields != null) { + this.analysesFields = new FetchSourceContext(true, config.analysesFields.includes(), config.analysesFields.excludes()); + } } public Builder setId(String id) { @@ -284,13 +308,18 @@ public Builder setQuery(Map query, boolean lenient) { return this; } + public Builder setAnalysesFields(FetchSourceContext fields) { + this.analysesFields = fields; + return this; + } + public Builder setHeaders(Map headers) { this.headers = headers; return this; } 
public DataFrameAnalyticsConfig build() { - return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers); + return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers, analysesFields); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 4bd3ccbc985d7..6083bf3608647 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -51,6 +51,8 @@ public final class Messages { public static final String DATAFEED_ID_ALREADY_TAKEN = "A datafeed with id [{0}] already exists"; public static final String DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT = "Data Frame Analytics config [{0}] query is not parsable"; + public static final String DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER = + "No compatible fields could be detected in index [{0}] with name [{1}]"; public static final String FILTER_CANNOT_DELETE = "Cannot delete filter [{0}] currently used by jobs {1}"; public static final String FILTER_CONTAINS_TOO_MANY_ITEMS = "Filter [{0}] contains too many items; up to [{1}] items are allowed"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 6f60a549ae707..fe21147d20be8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -398,6 +398,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(DataFrameAnalyticsConfig.DEST.getPreferredName()) .field(TYPE, 
KEYWORD) .endObject() + .startObject(DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName()) + .field(ENABLED, false) + .endObject() .startObject(DataFrameAnalyticsConfig.ANALYSES.getPreferredName()) .startObject(PROPERTIES) .startObject("outlier_detection") diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 62e9f78e1c826..378f5da401d2b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -261,6 +261,7 @@ public final class ReservedFieldNames { DataFrameAnalyticsConfig.SOURCE.getPreferredName(), DataFrameAnalyticsConfig.DEST.getPreferredName(), DataFrameAnalyticsConfig.ANALYSES.getPreferredName(), + DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName(), "outlier_detection", "method", "number_neighbours", diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 987b0deceb77d..478b194ed5af6 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; @@ -87,6 +88,11 @@ public static 
DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) { Collections.singletonMap(TermQueryBuilder.NAME, Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))), true); } + if (randomBoolean()) { + builder.setAnalysesFields(new FetchSourceContext(true, + generateRandomStringArray(10, 10, false, false), + generateRandomStringArray(10, 10, false, false))); + } return builder; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index 2f0db834d8578..f7334d29088df 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -12,12 +12,17 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.regex.Regex; import org.elasticsearch.index.IndexNotFoundException; import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NameResolver; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; @@ -94,8 +99,7 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { 
public static void create(Client client, DataFrameAnalyticsConfig config, ActionListener listener) { - - validateIndexAndExtractFields(client, config.getHeaders(), config.getDest(), ActionListener.wrap( + validateIndexAndExtractFields(client, config.getHeaders(), config.getDest(), config.getAnalysesFields(), ActionListener.wrap( extractedFields -> listener.onResponse( new DataFrameDataExtractorFactory(client, config.getId(), config.getDest(), extractedFields, config.getHeaders())), listener::onFailure @@ -112,7 +116,7 @@ public static void create(Client client, public static void validateConfigAndSourceIndex(Client client, DataFrameAnalyticsConfig config, ActionListener listener) { - validateIndexAndExtractFields(client, config.getHeaders(), config.getSource(), ActionListener.wrap( + validateIndexAndExtractFields(client, config.getHeaders(), config.getSource(), config.getAnalysesFields(), ActionListener.wrap( fields -> { config.getParsedQuery(); // validate query is acceptable listener.onResponse(true); @@ -122,10 +126,13 @@ public static void validateConfigAndSourceIndex(Client client, } // Visible for testing - static ExtractedFields detectExtractedFields(String index, FieldCapabilitiesResponse fieldCapabilitiesResponse) { + static ExtractedFields detectExtractedFields(String index, + FetchSourceContext desiredFields, + FieldCapabilitiesResponse fieldCapabilitiesResponse) { Set fields = fieldCapabilitiesResponse.get().keySet(); fields.removeAll(IGNORE_FIELDS); removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse); + includeAndExcludeFields(fields, desiredFields, index); List sortedFields = new ArrayList<>(fields); // We sort the fields to ensure the checksum for each document is deterministic Collections.sort(sortedFields); @@ -148,13 +155,45 @@ private static void removeFieldsWithIncompatibleTypes(Set fields, FieldC } } + private static void includeAndExcludeFields(Set fields, FetchSourceContext desiredFields, String index) { + if (desiredFields 
== null) { + return; + } + String includes = desiredFields.includes().length == 0 ? "*" : Strings.arrayToCommaDelimitedString(desiredFields.includes()); + String excludes = Strings.arrayToCommaDelimitedString(desiredFields.excludes()); + + if (Regex.isMatchAllPattern(includes) && excludes.isEmpty()) { + return; + } + try { + // If the inclusion set does not match anything, that means the user's desired fields cannot be found in + // the collection of supported field types. We should let the user know. + Set includedSet = NameResolver.newUnaliased(fields, + (ex) -> new ResourceNotFoundException(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER, index, ex))) + .expand(includes, false); + // If the exclusion set does not match anything, that means the fields are already not present + // no need to raise if nothing matched + Set excludedSet = NameResolver.newUnaliased(fields, + (ex) -> new ResourceNotFoundException(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER, index, ex))) + .expand(excludes, true); + + fields.retainAll(includedSet); + fields.removeAll(excludedSet); + } catch (ResourceNotFoundException ex) { + // Re-wrap our exception so that we throw the same exception type when there are no fields. + throw ExceptionsHelper.badRequestException(ex.getMessage()); + } + + } + private static void validateIndexAndExtractFields(Client client, Map headers, String index, + FetchSourceContext desiredFields, ActionListener listener) { // Step 2. 
Extract fields (if possible) and notify listener ActionListener fieldCapabilitiesHandler = ActionListener.wrap( - fieldCapabilitiesResponse -> listener.onResponse(detectExtractedFields(index, fieldCapabilitiesResponse)), + fieldCapabilitiesResponse -> listener.onResponse(detectExtractedFields(index, desiredFields, fieldCapabilitiesResponse)), e -> { if (e instanceof IndexNotFoundException) { listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java index 9aac76da91e95..f823c0c8c1e1f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java @@ -8,11 +8,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.fieldcaps.FieldCapabilities; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -27,12 +29,13 @@ public class DataFrameDataExtractorFactoryTests extends ESTestCase { private static final String INDEX = "source_index"; + private static final FetchSourceContext EMPTY_CONTEXT = new FetchSourceContext(true, new String[0], new String[0]); public void testDetectExtractedFields_GivenFloatField() { FieldCapabilitiesResponse fieldCapabilities= new 
MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float").build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); @@ -44,7 +47,7 @@ public void testDetectExtractedFields_GivenNumericFieldWithMultipleTypes() { .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") .build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); @@ -56,7 +59,7 @@ public void testDetectExtractedFields_GivenNonNumericField() { .addAggregatableField("some_keyword", "keyword").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -65,7 +68,7 @@ public void testDetectExtractedFields_GivenFieldWithNumericAndNonNumericTypes() .addAggregatableField("indecisive_field", "float", "keyword").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields 
could be detected in index [source_index]")); } @@ -76,7 +79,7 @@ public void testDetectExtractedFields_GivenMultipleFields() { .addAggregatableField("some_keyword", "keyword") .build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(2)); @@ -89,7 +92,7 @@ public void testDetectExtractedFields_GivenIgnoredField() { .addAggregatableField("_id", "float").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -108,13 +111,52 @@ public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() { } FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, fieldCapabilities); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, equalTo(sortedFields)); } + public void testDetectedExtractedFields_GivenIncludeWithMissingField() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("my_field1", "float") + .addAggregatableField("my_field2", "float") + .build(); + + FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", 
"my*"}, new String[0]); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index] with name [your_field1]")); + } + + public void testDetectedExtractedFields_GivenExcludeAllValidFields() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("my_field1", "float") + .addAggregatableField("my_field2", "float") + .build(); + + FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, fieldCapabilities)); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); + } + + public void testDetectedExtractedFields_GivenInclusionsAndExclusions() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("my_field1_nope", "float") + .addAggregatableField("my_field1", "float") + .addAggregatableField("your_field2", "float") + .addAggregatableField("your_keyword", "keyword") + .build(); + + FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, fieldCapabilities); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); + } + private static class MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>(); diff --git 
a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index a8a1c10de0164..da6c27533c005 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -26,7 +26,7 @@ - match: { data_frame_analytics: [] } --- -"Test put valid config with default outlier detection and query": +"Test put valid config with default outlier detection, query, and filter": - do: ml.put_data_frame_analytics: @@ -36,13 +36,15 @@ "source": "index-source", "dest": "index-dest", "analyses": [{"outlier_detection":{}}], - "query": {"term" : { "user" : "Kimchy" }} + "query": {"term" : { "user" : "Kimchy" }}, + "analyses_fields": [ "obj1.*", "obj2.*" ] } - match: { id: "simple-outlier-detection-with-query" } - match: { source: "index-source" } - match: { dest: "index-dest" } - match: { analyses: [{"outlier_detection":{}}] } - match: { query: {"term" : { "user" : "Kimchy"} } } + - match: { analyses_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } --- "Test put config with security headers in the body": From 90eaf94327b64eb5fc746c8849b9d33b948716dc Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 20 Feb 2019 16:12:28 +0200 Subject: [PATCH 22/67] [FEATURE][ML] Add timeout to the start DF analytics request (#39177) --- .../action/StartDataFrameAnalyticsAction.java | 47 +++++++++++++++++-- .../StartDataFrameAnalyticsRequestTests.java | 5 ++ .../ml/qa/ml-with-security/build.gradle | 1 + ...ransportStartDataFrameAnalyticsAction.java | 9 ++-- .../RestStartDataFrameAnalyticsAction.java | 14 +++++- .../api/ml.start_data_frame_analytics.json | 10 ++++ .../test/ml/start_data_frame_analytics.yml | 12 +++++ 7 files changed, 88 insertions(+), 10 deletions(-) diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java index 2a7965261856f..aec69e3f4e9a3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java @@ -13,10 +13,13 @@ import org.elasticsearch.action.support.master.MasterNodeRequest; import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -24,6 +27,7 @@ import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -45,23 +49,55 @@ public AcknowledgedResponse newResponse() { public static class Request extends MasterNodeRequest implements ToXContentObject { + public static final ParseField TIMEOUT = new ParseField("timeout"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + + static { + PARSER.declareString((request, id) -> request.id = id, DataFrameAnalyticsConfig.ID); + 
PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); + } + + public static Request parseRequest(String id, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (request.getId() == null) { + request.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(request.getId())) { + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + request.getId(), id)); + } + return request; + } + private String id; + private TimeValue timeout = TimeValue.timeValueSeconds(20); public Request(String id) { - this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + setId(id); } public Request(StreamInput in) throws IOException { readFrom(in); } - public Request() { + public Request() {} + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); } public String getId() { return id; } + public void setTimeout(TimeValue timeout) { + this.timeout = timeout; + } + + public TimeValue getTimeout() { + return timeout; + } + @Override public ActionRequestValidationException validate() { return null; @@ -71,12 +107,14 @@ public ActionRequestValidationException validate() { public void readFrom(StreamInput in) throws IOException { super.readFrom(in); id = in.readString(); + timeout = in.readTimeValue(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(id); + out.writeTimeValue(timeout); } @Override @@ -84,12 +122,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (id != null) { builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); } + builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); return builder; } @Override public int hashCode() { - return Objects.hash(id); + return Objects.hash(id, timeout); } @Override @@ -101,7 +140,7 @@ 
public boolean equals(Object obj) { return false; } StartDataFrameAnalyticsAction.Request other = (StartDataFrameAnalyticsAction.Request) obj; - return Objects.equals(id, other.id); + return Objects.equals(id, other.id) && Objects.equals(timeout, other.timeout); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java index 9875c87f5ef3c..a7025976134d7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.Request; @@ -13,6 +14,10 @@ public class StartDataFrameAnalyticsRequestTests extends AbstractWireSerializing @Override protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + if (randomBoolean()) { + request.setTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } return new Request(randomAlphaOfLength(20)); } diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 2003f3c18aa2e..1693b92d1e0c1 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -99,6 +99,7 @@ integTestRunner { 'ml/start_data_frame_analytics/Test start given missing config', 'ml/start_data_frame_analytics/Test start given missing source index', 'ml/start_data_frame_analytics/Test start given source index has no compatible fields', + 
'ml/start_data_frame_analytics/Test start with inconsistent body/param ids', 'ml/start_stop_datafeed/Test start datafeed job, but not open', 'ml/start_stop_datafeed/Test start non existing datafeed', 'ml/start_stop_datafeed/Test stop non existing datafeed', diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index cb7e84651ac73..02037b65407fa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -115,7 +115,7 @@ protected void masterOperation(StartDataFrameAnalyticsAction.Request request, Cl new ActionListener>() { @Override public void onResponse(PersistentTasksCustomMetaData.PersistentTask task) { - waitForAnalyticsStarted(task, listener); + waitForAnalyticsStarted(task, request.getTimeout(), listener); } @Override @@ -147,10 +147,10 @@ public void onFailure(Exception e) { } private void waitForAnalyticsStarted(PersistentTasksCustomMetaData.PersistentTask task, - ActionListener listener) { + TimeValue timeout, ActionListener listener) { AnalyticsPredicate predicate = new AnalyticsPredicate(); - // TODO Add timeout parameter to the start analytics request and use it here instead of hardcoded value - persistentTasksService.waitForPersistentTaskCondition(task.getId(), predicate, TimeValue.timeValueSeconds(10), + persistentTasksService.waitForPersistentTaskCondition(task.getId(), predicate, timeout, + new PersistentTasksService.WaitForPersistentTaskListener() { @Override @@ -280,6 +280,7 @@ protected DiscoveryNode selectLeastLoadedNode(ClusterState clusterState, Predica @Override protected void nodeOperation(AllocatedPersistentTask task, StartDataFrameAnalyticsAction.TaskParams params, PersistentTaskState 
state) { + LOGGER.info("[{}] Starting data frame analytics", params.getId()); DataFrameAnalyticsTaskState startedState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.STARTED, task.getAllocationId()); task.updatePersistentTaskState(startedState, ActionListener.wrap( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java index 035822429cc29..7502f31375f1a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStartDataFrameAnalyticsAction.java @@ -7,6 +7,7 @@ import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; @@ -33,8 +34,17 @@ public String getName() { @Override protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { String id = restRequest.param(DataFrameAnalyticsConfig.ID.getPreferredName()); - StartDataFrameAnalyticsAction.Request request = new StartDataFrameAnalyticsAction.Request(id); - + StartDataFrameAnalyticsAction.Request request; + if (restRequest.hasContentOrSourceParam()) { + request = StartDataFrameAnalyticsAction.Request.parseRequest(id, restRequest.contentOrSourceParamParser()); + } else { + request = new StartDataFrameAnalyticsAction.Request(id); + if (restRequest.hasParam(StartDataFrameAnalyticsAction.Request.TIMEOUT.getPreferredName())) { + TimeValue timeout = restRequest.paramAsTime(StartDataFrameAnalyticsAction.Request.TIMEOUT.getPreferredName(), + request.getTimeout()); + request.setTimeout(timeout); + } + } return 
channel -> client.execute(StartDataFrameAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel)); } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json index b4e61b3fab125..dfe0cac2f7b67 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json @@ -10,7 +10,17 @@ "required": true, "description": "The ID of the data frame analytics to start" } + }, + "params": { + "timeout": { + "type": "time", + "required": false, + "description": "Controls the time to wait until the task has started. Defaults to 20 seconds" + } } + }, + "body": { + "description": "The start data frame analytics parameters" } } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml index 36e50b0229737..5042f80c2b7b8 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml @@ -44,3 +44,15 @@ catch: /No compatible fields could be detected in index \[empty-index\]/ ml.start_data_frame_analytics: id: "foo" + +--- +"Test start with inconsistent body/param ids": + + - do: + catch: /Inconsistent id; 'body_id' specified in the body differs from 'url_id' specified as a URL argument/ + ml.start_data_frame_analytics: + id: "url_id" + body: > + { + "id": "body_id" + } From 9faeb94bf2276de7a4dd2dc094f4e05d5f8de971 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 27 Feb 2019 15:50:26 -0600 Subject: [PATCH 23/67] [DATA-FRAME-ANALYTICS] Add task recovery on node change (#39416) * [DATA-FRAME-ANALYTICS] Add task recovery on node change * Adding tests for recovery 
given task state * Addressing PR comments --- x-pack/plugin/ml/build.gradle | 2 + ...ransportStartDataFrameAnalyticsAction.java | 23 +- .../dataframe/DataFrameAnalyticsManager.java | 55 ++- .../dataframe/analyses/DataFrameAnalysis.java | 8 + .../dataframe/analyses/OutlierDetection.java | 15 + .../DataFrameDataExtractorFactory.java | 38 +- .../plugin-security-test.policy | 5 + .../DataFrameDataExtractorFactoryTests.java | 43 +- .../DataFrameAnalyticsManagerIT.java | 410 ++++++++++++++++++ .../xpack/ml/support/BaseMlIntegTestCase.java | 21 + 10 files changed, 584 insertions(+), 36 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java diff --git a/x-pack/plugin/ml/build.gradle b/x-pack/plugin/ml/build.gradle index 0fe0af236f9ad..807a4ed1aadbc 100644 --- a/x-pack/plugin/ml/build.gradle +++ b/x-pack/plugin/ml/build.gradle @@ -42,6 +42,7 @@ dependencies { compileOnly project(path: xpackModule('core'), configuration: 'default') compileOnly "org.elasticsearch.plugin:elasticsearch-scripting-painless-spi:${versions.elasticsearch}" testCompile project(path: xpackModule('core'), configuration: 'testArtifacts') + testCompile project(':modules:lang-painless') // This should not be here testCompile project(path: xpackModule('security'), configuration: 'testArtifacts') @@ -100,6 +101,7 @@ task internalClusterTest(type: RandomizedTestingTask, dependsOn: unitTest.dependsOn) { include '**/*IT.class' systemProperty 'es.set.netty.runtime.available.processors', 'false' + systemProperties 'java.security.policy': file("src/main/plugin-metadata/plugin-security-test.policy").absolutePath } check.dependsOn internalClusterTest internalClusterTest.mustRunAfter test diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 02037b65407fa..8e9029c38c46f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -281,12 +281,23 @@ protected DiscoveryNode selectLeastLoadedNode(ClusterState clusterState, Predica protected void nodeOperation(AllocatedPersistentTask task, StartDataFrameAnalyticsAction.TaskParams params, PersistentTaskState state) { LOGGER.info("[{}] Starting data frame analytics", params.getId()); - DataFrameAnalyticsTaskState startedState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.STARTED, - task.getAllocationId()); - task.updatePersistentTaskState(startedState, ActionListener.wrap( - response -> manager.execute((DataFrameAnalyticsTask) task), - task::markAsFailed - )); + DataFrameAnalyticsTaskState analyticsTaskState = (DataFrameAnalyticsTaskState) state; + + // If we are "stopping" there is nothing to do + if (analyticsTaskState != null && analyticsTaskState.getState() == DataFrameAnalyticsState.STOPPING) { + return; + } + + if (analyticsTaskState == null) { + DataFrameAnalyticsTaskState startedState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.STARTED, + task.getAllocationId()); + task.updatePersistentTaskState(startedState, ActionListener.wrap( + response -> manager.execute((DataFrameAnalyticsTask) task, DataFrameAnalyticsState.STARTED), + task::markAsFailed)); + } else { + manager.execute((DataFrameAnalyticsTask)task, analyticsTaskState.getState()); + } + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 2ad1de10b76db..3f2f568c8d419 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -9,6 +9,8 @@ import org.elasticsearch.action.admin.indices.create.CreateIndexAction; import org.elasticsearch.action.admin.indices.create.CreateIndexRequest; import org.elasticsearch.action.admin.indices.create.CreateIndexResponse; +import org.elasticsearch.action.admin.indices.delete.DeleteIndexAction; +import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.admin.indices.refresh.RefreshResponse; @@ -32,6 +34,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; @@ -76,21 +79,55 @@ public DataFrameAnalyticsManager(ClusterService clusterService, NodeClient clien this.processManager = Objects.requireNonNull(processManager); } - public void execute(DataFrameAnalyticsTask task) { + public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState currentState) { ActionListener reindexingStateListener = ActionListener.wrap( config -> reindexDataframeAndStartAnalysis(task, config), - e -> task.markAsFailed(e) + task::markAsFailed ); - // Update task state to REINDEXING + // With config in hand, determine action to take ActionListener configListener = 
ActionListener.wrap( config -> { DataFrameAnalyticsTaskState reindexingState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.REINDEXING, task.getAllocationId()); - task.updatePersistentTaskState(reindexingState, ActionListener.wrap( - updatedTask -> reindexingStateListener.onResponse(config), - reindexingStateListener::onFailure - )); + switch(currentState) { + // If we are STARTED, we are right at the beginning of our task, we should indicate that we are entering the + // REINDEX state and start reindexing. + case STARTED: + task.updatePersistentTaskState(reindexingState, ActionListener.wrap( + updatedTask -> reindexingStateListener.onResponse(config), + reindexingStateListener::onFailure)); + break; + // The task has fully reindexed the documents and we should continue on with our analyses + case ANALYZING: + // TODO apply previously stored model state if applicable + startAnalytics(task, config); + break; + // If we are already at REINDEXING, we are not 100% sure if we reindexed ALL the docs. + // We will delete the destination index, recreate, reindex + case REINDEXING: + ClientHelper.executeAsyncWithOrigin(client, + ML_ORIGIN, + DeleteIndexAction.INSTANCE, + new DeleteIndexRequest(config.getDest()), + ActionListener.wrap( + r-> reindexingStateListener.onResponse(config), + e -> { + if (e instanceof IndexNotFoundException) { + reindexingStateListener.onResponse(config); + } else { + reindexingStateListener.onFailure(e); + } + } + )); + break; + default: + reindexingStateListener.onFailure( + ExceptionsHelper.conflictStatusException( + "Cannot execute analytics task [{}] as it is currently in state [{}]. 
" + + "Must be one of [STARTED, REINDEXING, ANALYZING]", config.getId(), currentState)); + } + }, reindexingStateListener::onFailure ); @@ -114,7 +151,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF RefreshAction.INSTANCE, new RefreshRequest(config.getDest()), refreshListener), - e -> task.markAsFailed(e) + task::markAsFailed ); // Reindex @@ -159,7 +196,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi task::markAsFailed )); }, - e -> task.markAsFailed(e) + task::markAsFailed ); // TODO This could fail with errors. In that case we get stuck with the copied index. diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java index 9fdd093fa324e..03139ac8c9edc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java @@ -10,6 +10,7 @@ import java.util.Locale; import java.util.Map; +import java.util.Set; public interface DataFrameAnalysis extends ToXContentObject { @@ -32,6 +33,13 @@ public String toString() { Type getType(); + /** + * The fields that will contain the results of the analysis + * + * @return Set of Strings representing the result fields for the constructed analysis + */ + Set getResultFields(); + interface Factory { /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java index 47f614ba658f6..3b93f373546a2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java @@ -5,9 +5,12 @@ */ package org.elasticsearch.xpack.ml.dataframe.analyses; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashSet; import java.util.Locale; import java.util.Map; +import java.util.Set; public class OutlierDetection extends AbstractDataFrameAnalysis { @@ -27,6 +30,13 @@ public String toString() { public static final String NUMBER_NEIGHBOURS = "number_neighbours"; public static final String METHOD = "method"; + private static final Set RESULT_FIELDS; + static { + Set set = new LinkedHashSet<>(); + set.add("outlier_score"); + RESULT_FIELDS = Collections.unmodifiableSet(set); + } + private final Integer numberNeighbours; private final Method method; @@ -52,6 +62,11 @@ protected Map getParams() { return params; } + @Override + public Set getResultFields() { + return RESULT_FIELDS; + } + static class Factory implements DataFrameAnalysis.Factory { @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index f7334d29088df..d8c7b44d496ad 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -25,6 +25,8 @@ import org.elasticsearch.xpack.core.ml.utils.NameResolver; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; +import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; +import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; import java.util.ArrayList; import java.util.Arrays; @@ -99,10 +101,12 @@ public 
DataFrameDataExtractor newExtractor(boolean includeSource) { public static void create(Client client, DataFrameAnalyticsConfig config, ActionListener listener) { - validateIndexAndExtractFields(client, config.getHeaders(), config.getDest(), config.getAnalysesFields(), ActionListener.wrap( - extractedFields -> listener.onResponse( - new DataFrameDataExtractorFactory(client, config.getId(), config.getDest(), extractedFields, config.getHeaders())), - listener::onFailure + Set resultFields = resolveResultsFields(config); + validateIndexAndExtractFields(client, config.getHeaders(), config.getDest(), config.getAnalysesFields(), resultFields, + ActionListener.wrap( + extractedFields -> listener.onResponse( + new DataFrameDataExtractorFactory(client, config.getId(), config.getDest(), extractedFields, config.getHeaders())), + listener::onFailure )); } @@ -116,21 +120,26 @@ public static void create(Client client, public static void validateConfigAndSourceIndex(Client client, DataFrameAnalyticsConfig config, ActionListener listener) { - validateIndexAndExtractFields(client, config.getHeaders(), config.getSource(), config.getAnalysesFields(), ActionListener.wrap( - fields -> { - config.getParsedQuery(); // validate query is acceptable - listener.onResponse(true); - }, - listener::onFailure + Set resultFields = resolveResultsFields(config); + validateIndexAndExtractFields(client, config.getHeaders(), config.getSource(), config.getAnalysesFields(), resultFields, + ActionListener.wrap( + fields -> { + config.getParsedQuery(); // validate query is acceptable + listener.onResponse(true); + }, + listener::onFailure )); } // Visible for testing static ExtractedFields detectExtractedFields(String index, FetchSourceContext desiredFields, + Set resultFields, FieldCapabilitiesResponse fieldCapabilitiesResponse) { Set fields = fieldCapabilitiesResponse.get().keySet(); fields.removeAll(IGNORE_FIELDS); + // TODO a better solution may be to have some sort of known prefix and filtering that 
+ fields.removeAll(resultFields); removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse); includeAndExcludeFields(fields, desiredFields, index); List sortedFields = new ArrayList<>(fields); @@ -190,10 +199,12 @@ private static void validateIndexAndExtractFields(Client client, Map headers, String index, FetchSourceContext desiredFields, + Set resultFields, ActionListener listener) { // Step 2. Extract fields (if possible) and notify listener ActionListener fieldCapabilitiesHandler = ActionListener.wrap( - fieldCapabilitiesResponse -> listener.onResponse(detectExtractedFields(index, desiredFields, fieldCapabilitiesResponse)), + fieldCapabilitiesResponse -> listener.onResponse( + detectExtractedFields(index, desiredFields, resultFields, fieldCapabilitiesResponse)), e -> { if (e instanceof IndexNotFoundException) { listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " @@ -214,4 +225,9 @@ private static void validateIndexAndExtractFields(Client client, return null; }); } + + private static Set resolveResultsFields(DataFrameAnalyticsConfig config) { + List analyses = DataFrameAnalysesUtils.readAnalyses(config.getAnalyses()); + return analyses.stream().flatMap(analysis -> analysis.getResultFields().stream()).collect(Collectors.toSet()); + } } diff --git a/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy new file mode 100644 index 0000000000000..d090016eac620 --- /dev/null +++ b/x-pack/plugin/ml/src/main/plugin-metadata/plugin-security-test.policy @@ -0,0 +1,5 @@ +// Needed for painless script to run +grant { + // needed to create the classloader which allows plugins to extend other plugins + permission java.lang.RuntimePermission "createClassLoader"; +}; \ No newline at end of file diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java index f823c0c8c1e1f..affc9cb0b0ca8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java @@ -19,6 +19,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -30,12 +31,14 @@ public class DataFrameDataExtractorFactoryTests extends ESTestCase { private static final String INDEX = "source_index"; private static final FetchSourceContext EMPTY_CONTEXT = new FetchSourceContext(true, new String[0], new String[0]); + private static final Set EMPTY_RESULT_SET = Collections.emptySet(); public void testDetectExtractedFields_GivenFloatField() { FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float").build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); + ExtractedFields extractedFields = + DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); @@ -47,7 +50,8 @@ public void testDetectExtractedFields_GivenNumericFieldWithMultipleTypes() { .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") .build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); + ExtractedFields extractedFields = + DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, 
fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); @@ -59,7 +63,7 @@ public void testDetectExtractedFields_GivenNonNumericField() { .addAggregatableField("some_keyword", "keyword").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -68,7 +72,7 @@ public void testDetectExtractedFields_GivenFieldWithNumericAndNonNumericTypes() .addAggregatableField("indecisive_field", "float", "keyword").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -79,7 +83,8 @@ public void testDetectExtractedFields_GivenMultipleFields() { .addAggregatableField("some_keyword", "keyword") .build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); + ExtractedFields extractedFields = + DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(2)); @@ -92,7 +97,7 @@ public void testDetectExtractedFields_GivenIgnoredField() { .addAggregatableField("_id", "float").build(); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> 
DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -111,7 +116,8 @@ public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() { } FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, fieldCapabilities); + ExtractedFields extractedFields = + DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); @@ -126,7 +132,7 @@ public void testDetectedExtractedFields_GivenIncludeWithMissingField() { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, EMPTY_RESULT_SET, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index] with name [your_field1]")); } @@ -138,7 +144,7 @@ public void testDetectedExtractedFields_GivenExcludeAllValidFields() { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, fieldCapabilities)); + () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, 
desiredFields, EMPTY_RESULT_SET, fieldCapabilities)); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -151,7 +157,24 @@ public void testDetectedExtractedFields_GivenInclusionsAndExclusions() { .build(); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, fieldCapabilities); + ExtractedFields extractedFields = + DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, EMPTY_RESULT_SET, fieldCapabilities); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); + } + + public void testDetectedExtractedFields_GivenAResultField() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("outlier_score", "float") + .addAggregatableField("my_field1", "float") + .addAggregatableField("your_field2", "float") + .addAggregatableField("your_keyword", "keyword") + .build(); + ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, + EMPTY_CONTEXT, + Collections.singleton("outlier_score"), + fieldCapabilities); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java new file mode 100644 index 0000000000000..b7e23902ce2b2 --- /dev/null +++ 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java @@ -0,0 +1,410 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.analysis.common.CommonAnalysisPlugin; +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.index.reindex.ReindexPlugin; +import org.elasticsearch.painless.PainlessPlugin; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.transport.Netty4Plugin; +import org.elasticsearch.xpack.core.XPackClientPlugin; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.ml.LocalStateMachineLearning; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; +import 
org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsFields; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcess; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessConfig; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; +import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsResult; +import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.junit.Assert; +import org.junit.Before; + +import java.io.IOException; +import java.time.ZonedDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.times; + +public class DataFrameAnalyticsManagerIT extends BaseMlIntegTestCase { + + private volatile boolean finished; + private DataFrameAnalyticsConfigProvider provider; + private static double EXPECTED_OUTLIER_SCORE = 42.0; + @Before + public void fieldSetup() { + provider = new DataFrameAnalyticsConfigProvider(client()); + finished = false; + } + + @Override + protected 
Collection> nodePlugins() { + return Arrays.asList(LocalStateMachineLearning.class, CommonAnalysisPlugin.class, + ReindexPlugin.class, PainlessPlugin.class); + } + + @Override + protected Collection> transportClientPlugins() { + return Arrays.asList(XPackClientPlugin.class, Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class); + } + + public void testTaskContinuationFromReindexState() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(1); + ensureStableCluster(1); + String sourceIndex = "test-outlier-detection-from-reindex-state"; + createIndexForAnalysis(sourceIndex); + String id = "test_outlier_detection_from_reindex_state"; + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); + putDataFrameAnalyticsConfig(config); + List results = buildExpectedResults(sourceIndex); + + DataFrameAnalyticsManager manager = createManager(results); + + DataFrameAnalyticsTask task = buildMockedTask(config.getId()); + manager.execute(task, DataFrameAnalyticsState.REINDEXING); + + // wait for markAsCompleted() or markAsFailed() to be called; + assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); + + // Check we've got all docs + SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); + for(SearchHit hit : searchResponse.getHits().getHits()) { + Map src = hit.getSourceAsMap(); + assertNotNull(src.get("outlier_score")); + assertThat(src.get("outlier_score"), equalTo(EXPECTED_OUTLIER_SCORE)); + } + + verify(task, never()).markAsFailed(any(Exception.class)); + verify(task, times(1)).markAsCompleted(); + } + + public void testTaskContinuationFromReindexStateWithPreviousResultsIndex() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(1); + ensureStableCluster(1); + String sourceIndex = "test-outlier-detection-from-reindex-state-with-results"; + 
createIndexForAnalysis(sourceIndex); + String id = "test_outlier_detection_from_reindex_state_with_results"; + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); + putDataFrameAnalyticsConfig(config); + + // Create the "results" index, as if we ran reindex already in the process, but did not transition from the state properly + createAnalysesResultsIndex(config.getDest(), false); + List results = buildExpectedResults(sourceIndex); + + DataFrameAnalyticsManager manager = createManager(results); + + DataFrameAnalyticsTask task = buildMockedTask(config.getId()); + manager.execute(task, DataFrameAnalyticsState.REINDEXING); + + // wait for markAsCompleted() or markAsFailed() to be called; + assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); + + // Check we've got all docs + SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); + for(SearchHit hit : searchResponse.getHits().getHits()) { + Map src = hit.getSourceAsMap(); + assertNotNull(src.get("outlier_score")); + assertThat(src.get("outlier_score"), equalTo(EXPECTED_OUTLIER_SCORE)); + } + + verify(task, never()).markAsFailed(any(Exception.class)); + verify(task, times(1)).markAsCompleted(); + } + + public void testTaskContinuationFromAnalyzeState() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(1); + ensureStableCluster(1); + String sourceIndex = "test-outlier-detection-from-analyze-state"; + createIndexForAnalysis(sourceIndex); + String id = "test_outlier_detection_from_analyze_state"; + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); + putDataFrameAnalyticsConfig(config); + // Create the "results" index to simulate running reindex already and having partially ran analysis + createAnalysesResultsIndex(config.getDest(), true); + List results = buildExpectedResults(sourceIndex); + + 
DataFrameAnalyticsManager manager = createManager(results); + + DataFrameAnalyticsTask task = buildMockedTask(config.getId()); + manager.execute(task, DataFrameAnalyticsState.ANALYZING); + + // wait for markAsCompleted() or markAsFailed() to be called; + assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); + + // Check we've got all docs + SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); + for(SearchHit hit : searchResponse.getHits().getHits()) { + Map src = hit.getSourceAsMap(); + assertNotNull(src.get("outlier_score")); + assertThat(src.get("outlier_score"), equalTo(EXPECTED_OUTLIER_SCORE)); + } + + verify(task, never()).markAsFailed(any(Exception.class)); + verify(task, times(1)).markAsCompleted(); + // Need to verify that we did not reindex again, as we already had the full destination index + verify(task, never()).setReindexingTaskId(anyLong()); + } + + private synchronized void completed() { + finished = true; + } + + private synchronized boolean isCompleted() { + return finished; + } + + private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex) { + DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); + configBuilder.setSource(sourceIndex); + configBuilder.setDest(sourceIndex + "-results"); + Map analysisConfig = new HashMap<>(); + analysisConfig.put("outlier_detection", Collections.emptyMap()); + configBuilder.setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig(analysisConfig))); + return configBuilder.build(); + } + + @SuppressWarnings("unchecked") + private void putDataFrameAnalyticsConfig(DataFrameAnalyticsConfig config) throws Exception { + PlainActionFuture future = new PlainActionFuture(); + provider.put(config, Collections.emptyMap(), future); + future.get(); + } + + private void 
createIndexForAnalysis(String indexName) { + client().admin().indices().prepareCreate(indexName) + .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical_1", "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (int i = 0; i < 5; i++) { + IndexRequest indexRequest = new IndexRequest(indexName); + + // We insert one odd value out of 5 for one feature + String docId = i == 0 ? "outlier" : "normal" + i; + indexRequest.id(docId); + indexRequest.source("numeric_1", i == 0 ? 100.0 : 1.0, "numeric_2", 1.0, "categorical_1", "foo_" + i); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + Assert.fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } + + private void createAnalysesResultsIndex(String indexName, boolean includeOutlierScore) { + client().admin().indices().prepareCreate(indexName) + .addMapping("_doc", + "numeric_1", "type=double", + "numeric_2", "type=float", + "categorical_1", "type=keyword", + DataFrameAnalyticsFields.ID, "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (int i = 0; i < 5; i++) { + IndexRequest indexRequest = new IndexRequest(indexName); + + // We insert one odd value out of 5 for one feature + String docId = i == 0 ? "outlier" : "normal" + i; + indexRequest.id(docId); + if (includeOutlierScore) { + indexRequest.source("numeric_1", i == 0 ? 100.0 : 1.0, + "numeric_2", 1.0, + "categorical_1", "foo_" + i, + DataFrameAnalyticsFields.ID, docId, + "outlier_score", 10.0); // simply needs to be a score different than expected + } else { + indexRequest.source("numeric_1", i == 0 ? 
100.0 : 1.0, + "numeric_2", 1.0, + "categorical_1", "foo_" + i, + DataFrameAnalyticsFields.ID, docId); + } + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + Assert.fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + } + + private List buildExpectedResults(String index) throws Exception { + SearchHit[] hits = client().search(new SearchRequest(index)).get().getHits().getHits(); + Arrays.sort(hits, Comparator.comparing(SearchHit::getId)); + List results = new ArrayList<>(hits.length); + for (SearchHit hit : hits) { + String[] fields = new String[2]; + Map src = hit.getSourceAsMap(); + fields[0] = src.get("numeric_1").toString(); + fields[1] = src.get("numeric_2").toString(); + results.add(new AnalyticsResult(new RowResults(Arrays.hashCode(fields), + Collections.singletonMap("outlier_score", EXPECTED_OUTLIER_SCORE)),null)); + } + return results; + } + + private DataFrameAnalyticsManager createManager(List expectedResults) { + AnalyticsProcessFactory factory = new MockedAnalyticsFactory(expectedResults); + AnalyticsProcessManager processManager = new AnalyticsProcessManager(client(), + TestEnvironment.newEnvironment(internalCluster().getDefaultSettings()), + client().threadPool(), + factory); + return new DataFrameAnalyticsManager(clusterService(), (NodeClient)internalCluster().dataNodeClient(), provider, processManager); + } + + @SuppressWarnings("unchecked") + private DataFrameAnalyticsTask buildMockedTask(String id) { + StartDataFrameAnalyticsAction.TaskParams params = new StartDataFrameAnalyticsAction.TaskParams(id); + DataFrameAnalyticsTask task = mock(DataFrameAnalyticsTask.class); + when(task.getParams()).thenReturn(params); + when(task.getAllocationId()).thenReturn(1L); + doAnswer(invoked -> { + client().threadPool().executor("listener").execute(() -> { + ActionListener listener = (ActionListener) invoked.getArguments()[1]; + final 
PersistentTasksCustomMetaData.PersistentTask resp = new PersistentTasksCustomMetaData.PersistentTask<>(id, + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + params, + 1, + new PersistentTasksCustomMetaData.Assignment(null, "none")); + listener.onResponse(resp); + }); + return null; + }).when(task).updatePersistentTaskState(any(DataFrameAnalyticsTaskState.class), any(ActionListener.class)); + doNothing().when(task).setReindexingTaskId(anyLong()); + doAnswer(invoked -> { + completed(); + return null; + }).when(task).markAsCompleted(); + doAnswer(invoked -> { + completed(); + Exception e = (Exception)invoked.getArguments()[0]; + fail(e.getMessage()); + return null; + }).when(task).markAsFailed(any(Exception.class)); + return task; + } + + class MockedAnalyticsFactory implements AnalyticsProcessFactory { + final List results; + + MockedAnalyticsFactory(List resultsToSupply) { + this.results = resultsToSupply; + } + @Override + public AnalyticsProcess createAnalyticsProcess(String jobId, + AnalyticsProcessConfig analyticsProcessConfig, + ExecutorService executorService) { + return new MockedAnalyticsProcess(results); + } + } + + class MockedAnalyticsProcess implements AnalyticsProcess { + + final List results; + final ZonedDateTime start; + MockedAnalyticsProcess(List resultsToSupply) { + results = resultsToSupply; + start = ZonedDateTime.now(); + } + + @Override + public void writeEndOfDataMessage() throws IOException { } + + @Override + public Iterator readAnalyticsResults() { + return results.iterator(); + } + + @Override + public void consumeAndCloseOutputStream() { } + + @Override + public boolean isReady() { + return true; + } + + @Override + public void writeRecord(String[] record) throws IOException { } + + @Override + public void persistState() throws IOException { } + + @Override + public void flushStream() throws IOException { } + + @Override + public void kill() throws IOException { } + + @Override + public ZonedDateTime getProcessStartTime() { + return start; + } 
+ + @Override + public boolean isProcessAlive() { + return true; + } + + @Override + public boolean isProcessAliveAfterWaiting() { + return false; + } + + @Override + public String readError() { + return null; + } + + @Override + public void close() throws IOException { } + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java index a95f341ed1512..42c722b2b0c04 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java @@ -29,8 +29,11 @@ import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.CloseJobAction; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetJobsAction; @@ -40,6 +43,8 @@ import org.elasticsearch.xpack.core.ml.client.MachineLearningClient; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import 
org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -206,6 +211,7 @@ public void cleanupWorkaround() throws Exception { logger.info("[{}#{}]: Cleaning up datafeeds and jobs after test", getTestClass().getSimpleName(), getTestName()); deleteAllDatafeeds(logger, client()); deleteAllJobs(logger, client()); + deleteAllDataFrameAnalytics(client()); assertBusy(() -> { RecoveryResponse recoveryResponse = client().admin().indices().prepareRecoveries() .setActiveOnly(true) @@ -349,6 +355,21 @@ public static void deleteAllJobs(Logger logger, Client client) throws Exception } } + public static void deleteAllDataFrameAnalytics(Client client) throws Exception { + final QueryPage analytics = + client.execute(GetDataFrameAnalyticsAction.INSTANCE, + new GetDataFrameAnalyticsAction.Request("_all")).get().getResources(); + + assertBusy(() -> { + GetDataFrameAnalyticsStatsAction.Response statsResponse = + client().execute(GetDataFrameAnalyticsStatsAction.INSTANCE, new GetDataFrameAnalyticsStatsAction.Request("_all")).get(); + assertTrue(statsResponse.getResponse().results().stream().allMatch(s -> s.getState().equals(DataFrameAnalyticsState.STOPPED))); + }); + for (final DataFrameAnalyticsConfig config : analytics.results()) { + client.execute(DeleteDataFrameAnalyticsAction.INSTANCE, new DeleteDataFrameAnalyticsAction.Request(config.getId())).actionGet(); + } + } + protected String awaitJobOpenedAndAssigned(String jobId, String queryNode) throws Exception { AtomicReference jobNode = new AtomicReference<>(); assertBusy(() -> { From 7637656efeb6a6bf0a5141e76c078eac1faa37ec Mon Sep 17 00:00:00 2001 From: David Roberts Date: Tue, 5 Mar 2019 17:53:59 +0000 Subject: [PATCH 24/67] [FEATURE][ML] Add model_memory_limit to data frame analytics (#39561) The way this works is as close as possible to how it works for anomaly detector jobs. 
This is because anomaly detector jobs and data frame analytics will share the same memory driven node allocation process so it makes sense for the limiting functionality to be the same too. --- .../xpack/core/ml/MachineLearningField.java | 2 +- .../dataframe/DataFrameAnalyticsConfig.java | 80 +++++++++++++++++-- .../DataFrameAnalyticsConfigTests.java | 67 ++++++++++++++++ .../ml/qa/ml-with-security/build.gradle | 1 + .../TransportPutDataFrameAnalyticsAction.java | 35 +++++--- .../process/AnalyticsProcessManager.java | 6 +- .../test/ml/data_frame_analytics_crud.yml | 59 ++++++++++++++ 7 files changed, 227 insertions(+), 23 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java index 6b5ba086c6fe0..5c3da41df7349 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java @@ -13,7 +13,7 @@ public final class MachineLearningField { public static final Setting AUTODETECT_PROCESS = Setting.boolSetting("xpack.ml.autodetect_process", true, Setting.Property.NodeScope); public static final Setting MAX_MODEL_MEMORY_LIMIT = - Setting.memorySizeSetting("xpack.ml.max_model_memory_limit", new ByteSizeValue(0), + Setting.memorySizeSetting("xpack.ml.max_model_memory_limit", ByteSizeValue.ZERO, Setting.Property.Dynamic, Setting.Property.NodeScope); public static final TimeValue STATE_PERSIST_RESTORE_TIMEOUT = TimeValue.timeValueMinutes(30); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 9948b899e2907..ef58fd9c3fde7 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -14,6 +14,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -37,12 +39,16 @@ import java.util.Objects; import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING; +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.VALUE; public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { private static final Logger logger = LogManager.getLogger(DataFrameAnalyticsConfig.class); public static final String TYPE = "data_frame_analytics_config"; + public static final ByteSizeValue DEFAULT_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.GB); + public static final ByteSizeValue MIN_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.MB); + private static final XContentObjectTransformer QUERY_TRANSFORMER = XContentObjectTransformer.queryBuilderTransformer(); static final TriFunction, String, List, QueryBuilder> lazyQueryParser = (objectMap, id, warnings) -> { @@ -68,6 +74,7 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { public static final ParseField CONFIG_TYPE = new ParseField("config_type"); public static final ParseField QUERY = new ParseField("query"); public static final ParseField ANALYSES_FIELDS = new ParseField("analyses_fields"); + public static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); public static final 
ParseField HEADERS = new ParseField("headers"); public static final ObjectParser STRICT_PARSER = createParser(false); @@ -86,6 +93,8 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel (p, c) -> FetchSourceContext.fromXContent(p), ANALYSES_FIELDS, OBJECT_ARRAY_BOOLEAN_OR_STRING); + parser.declareField(Builder::setModelMemoryLimit, + (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), MODEL_MEMORY_LIMIT.getPreferredName()), MODEL_MEMORY_LIMIT, VALUE); if (ignoreUnknownFields) { // Headers are not parsed by the strict (config) parser, so headers supplied in the _body_ of a REST request will be rejected. // (For config, headers are explicitly transferred from the auth headers by code in the put data frame actions.) @@ -101,10 +110,20 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final Map query; private final CachedSupplier querySupplier; private final FetchSourceContext analysesFields; + /** + * This may be null up to the point of persistence, as the relationship with xpack.ml.max_model_memory_limit + * depends on whether the user explicitly set the value or if the default was requested. null indicates + * the default was requested, which in turn means a default higher than the maximum is silently capped. + * A non-null value higher than xpack.ml.max_model_memory_limit will cause a + * validation error even if it is equal to the default value. This behaviour matches what is done in + * {@link org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits}. 
+ */ + private final ByteSizeValue modelMemoryLimit; private final Map headers; public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses, - Map query, Map headers, FetchSourceContext analysesFields) { + Map query, Map headers, ByteSizeValue modelMemoryLimit, + FetchSourceContext analysesFields) { this.id = ExceptionsHelper.requireNonNull(id, ID); this.source = ExceptionsHelper.requireNonNull(source, SOURCE); this.dest = ExceptionsHelper.requireNonNull(dest, DEST); @@ -119,6 +138,7 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); this.analysesFields = analysesFields; + this.modelMemoryLimit = modelMemoryLimit; this.headers = Collections.unmodifiableMap(headers); } @@ -130,6 +150,7 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException { this.query = in.readMap(); this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); this.analysesFields = in.readOptionalWriteable(FetchSourceContext::new); + this.modelMemoryLimit = in.readOptionalWriteable(ByteSizeValue::new); this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); } @@ -177,6 +198,10 @@ List getQueryDeprecations(TriFunction, String, List< return deprecations; } + public ByteSizeValue getModelMemoryLimit() { + return modelMemoryLimit != null ? 
modelMemoryLimit : DEFAULT_MODEL_MEMORY_LIMIT; + } + public Map getHeaders() { return headers; } @@ -195,6 +220,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (analysesFields != null) { builder.field(ANALYSES_FIELDS.getPreferredName(), analysesFields); } + builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), getModelMemoryLimit().getStringRep()); if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { builder.field(HEADERS.getPreferredName(), headers); } @@ -210,6 +236,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeList(analyses); out.writeMap(query); out.writeOptionalWriteable(analysesFields); + out.writeOptionalWriteable(modelMemoryLimit); out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); } @@ -225,12 +252,13 @@ public boolean equals(Object o) { && Objects.equals(analyses, other.analyses) && Objects.equals(query, other.query) && Objects.equals(headers, other.headers) + && Objects.equals(getModelMemoryLimit(), other.getModelMemoryLimit()) && Objects.equals(analysesFields, other.analysesFields); } @Override public int hashCode() { - return Objects.hash(id, source, dest, analyses, query, headers, analysesFields); + return Objects.hash(id, source, dest, analyses, query, headers, getModelMemoryLimit(), analysesFields); } public static String documentId(String id) { @@ -245,6 +273,8 @@ public static class Builder { private List analyses; private Map query = Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()); private FetchSourceContext analysesFields; + private ByteSizeValue modelMemoryLimit; + private ByteSizeValue maxModelMemoryLimit; private Map headers = Collections.emptyMap(); public Builder() {} @@ -253,22 +283,32 @@ public Builder(String id) { setId(id); } - public String getId() { - return id; + public Builder(ByteSizeValue maxModelMemoryLimit) { + this.maxModelMemoryLimit = maxModelMemoryLimit; 
+ } + + public Builder(DataFrameAnalyticsConfig config) { + this(config, null); } - public Builder(DataFrameAnalyticsConfig config) { + public Builder(DataFrameAnalyticsConfig config, ByteSizeValue maxModelMemoryLimit) { this.id = config.id; this.source = config.source; this.dest = config.dest; this.analyses = new ArrayList<>(config.analyses); this.query = new LinkedHashMap<>(config.query); this.headers = new HashMap<>(config.headers); + this.modelMemoryLimit = config.modelMemoryLimit; + this.maxModelMemoryLimit = maxModelMemoryLimit; if (config.analysesFields != null) { - this.analysesFields = new FetchSourceContext(true, config.analysesFields.includes(), config.analysesFields.excludes()); + this.analysesFields = new FetchSourceContext(true, config.analysesFields.includes(), config.analysesFields.excludes()); } } + public String getId() { + return id; + } + public Builder setId(String id) { this.id = ExceptionsHelper.requireNonNull(id, ID); return this; @@ -318,8 +358,34 @@ public Builder setHeaders(Map headers) { return this; } + public Builder setModelMemoryLimit(ByteSizeValue modelMemoryLimit) { + if (modelMemoryLimit != null && modelMemoryLimit.compareTo(MIN_MODEL_MEMORY_LIMIT) < 0) { + throw new IllegalArgumentException("[" + MODEL_MEMORY_LIMIT.getPreferredName() + + "] must be at least [" + MIN_MODEL_MEMORY_LIMIT.getStringRep() + "]"); + } + this.modelMemoryLimit = modelMemoryLimit; + return this; + } + + private void applyMaxModelMemoryLimit() { + + boolean maxModelMemoryIsSet = maxModelMemoryLimit != null && maxModelMemoryLimit.getMb() > 0; + + if (modelMemoryLimit == null) { + // Default is silently capped if higher than limit + if (maxModelMemoryIsSet && DEFAULT_MODEL_MEMORY_LIMIT.compareTo(maxModelMemoryLimit) > 0) { + modelMemoryLimit = maxModelMemoryLimit; + } + } else if (maxModelMemoryIsSet && modelMemoryLimit.compareTo(maxModelMemoryLimit) > 0) { + // Explicit setting higher than limit is an error + throw 
ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_MODEL_MEMORY_LIMIT_GREATER_THAN_MAX, + modelMemoryLimit, maxModelMemoryLimit)); + } + } + public DataFrameAnalyticsConfig build() { - return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers, analysesFields); + applyMaxModelMemoryLimit(); + return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers, modelMemoryLimit, analysesFields); } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 478b194ed5af6..498fe2a960b62 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -7,10 +7,13 @@ import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -35,9 +38,12 @@ import java.util.Map; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasItem; import static 
org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -93,6 +99,9 @@ public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) { generateRandomStringArray(10, 10, false, false), generateRandomStringArray(10, 10, false, false))); } + if (randomBoolean()) { + builder.setModelMemoryLimit(new ByteSizeValue(randomIntBetween(1, 16), randomFrom(ByteSizeUnit.MB, ByteSizeUnit.GB))); + } return builder; } @@ -201,4 +210,62 @@ public void testGetQueryDeprecations() { spiedConfig.getQueryDeprecations(); verify(spiedConfig).getQueryDeprecations(DataFrameAnalyticsConfig.lazyQueryParser); } + + public void testInvalidModelMemoryLimits() { + + DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(); + + // All these are different ways of specifying a limit that is lower than the minimum + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(1048575, ByteSizeUnit.BYTES)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.BYTES)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(-1, ByteSizeUnit.BYTES)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(1023, ByteSizeUnit.KB)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.KB)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.MB)))); + } + + public void testNoMemoryCapping() { + + DataFrameAnalyticsConfig uncapped = createRandom("foo"); + + ByteSizeValue unlimited = randomBoolean() ? 
null : ByteSizeValue.ZERO; + assertThat(uncapped.getModelMemoryLimit(), + equalTo(new DataFrameAnalyticsConfig.Builder(uncapped, unlimited).build().getModelMemoryLimit())); + } + + public void testMemoryCapping() { + + DataFrameAnalyticsConfig defaultLimitConfig = createRandomBuilder("foo").setModelMemoryLimit(null).build(); + + ByteSizeValue maxLimit = new ByteSizeValue(randomIntBetween(500, 1000), ByteSizeUnit.MB); + if (maxLimit.compareTo(defaultLimitConfig.getModelMemoryLimit()) < 0) { + assertThat(maxLimit, + equalTo(new DataFrameAnalyticsConfig.Builder(defaultLimitConfig, maxLimit).build().getModelMemoryLimit())); + } else { + assertThat(defaultLimitConfig.getModelMemoryLimit(), + equalTo(new DataFrameAnalyticsConfig.Builder(defaultLimitConfig, maxLimit).build().getModelMemoryLimit())); + } + } + + public void testExplicitModelMemoryLimitTooHigh() { + + ByteSizeValue configuredLimit = new ByteSizeValue(randomIntBetween(5, 10), ByteSizeUnit.GB); + DataFrameAnalyticsConfig explicitLimitConfig = createRandomBuilder("foo").setModelMemoryLimit(configuredLimit).build(); + + ByteSizeValue maxLimit = new ByteSizeValue(randomIntBetween(500, 1000), ByteSizeUnit.MB); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new DataFrameAnalyticsConfig.Builder(explicitLimitConfig, maxLimit).build()); + assertThat(e.getMessage(), startsWith("model_memory_limit")); + assertThat(e.getMessage(), containsString("must be less than the value of the xpack.ml.max_model_memory_limit setting")); + } + + public void assertTooSmall(IllegalArgumentException e) { + assertThat(e.getMessage(), is("[model_memory_limit] must be at least [1mb]")); + } } diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 1693b92d1e0c1..f033fd4ed3610 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -47,6 +47,7 @@ integTestRunner { 
'ml/data_frame_analytics_crud/Test put config given two analyses', 'ml/data_frame_analytics_crud/Test get given missing analytics', 'ml/data_frame_analytics_crud/Test delete given missing config', + 'ml/data_frame_analytics_crud/Test max model memory limit', 'ml/delete_job_force/Test cannot force delete a non-existent job', 'ml/delete_model_snapshot/Test delete snapshot missing snapshotId', 'ml/delete_model_snapshot/Test delete snapshot missing job_id', diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java index d0fbc613896f5..3be814a72fc2d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java @@ -9,9 +9,11 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.license.LicenseUtils; @@ -21,6 +23,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -48,10 
+51,12 @@ public class TransportPutDataFrameAnalyticsAction private final SecurityContext securityContext; private final Client client; + private volatile ByteSizeValue maxModelMemoryLimit; + @Inject public TransportPutDataFrameAnalyticsAction(Settings settings, TransportService transportService, ActionFilters actionFilters, XPackLicenseState licenseState, Client client, ThreadPool threadPool, - DataFrameAnalyticsConfigProvider configProvider) { + ClusterService clusterService, DataFrameAnalyticsConfigProvider configProvider) { super(PutDataFrameAnalyticsAction.NAME, transportService, actionFilters, (Supplier) PutDataFrameAnalyticsAction.Request::new); this.licenseState = licenseState; @@ -60,6 +65,14 @@ public TransportPutDataFrameAnalyticsAction(Settings settings, TransportService this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ? new SecurityContext(settings, threadPool.getThreadContext()) : null; this.client = client; + + maxModelMemoryLimit = MachineLearningField.MAX_MODEL_MEMORY_LIMIT.get(settings); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(MachineLearningField.MAX_MODEL_MEMORY_LIMIT, this::setMaxModelMemoryLimit); + } + + private void setMaxModelMemoryLimit(ByteSizeValue maxModelMemoryLimit) { + this.maxModelMemoryLimit = maxModelMemoryLimit; } @Override @@ -70,14 +83,16 @@ protected void doExecute(Task task, PutDataFrameAnalyticsAction.Request request, return; } validateConfig(request.getConfig()); + DataFrameAnalyticsConfig memoryCappedConfig = + new DataFrameAnalyticsConfig.Builder(request.getConfig(), maxModelMemoryLimit).build(); if (licenseState.isAuthAllowed()) { final String username = securityContext.getUser().principal(); RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() - .indices(request.getConfig().getSource()) + .indices(memoryCappedConfig.getSource()) .privileges("read") .build(); RoleDescriptor.IndicesPrivileges destIndexPrivileges = 
RoleDescriptor.IndicesPrivileges.builder() - .indices(request.getConfig().getDest()) + .indices(memoryCappedConfig.getDest()) .privileges("read", "index", "create_index") .build(); @@ -88,24 +103,24 @@ protected void doExecute(Task task, PutDataFrameAnalyticsAction.Request request, privRequest.indexPrivileges(sourceIndexPrivileges, destIndexPrivileges); ActionListener privResponseListener = ActionListener.wrap( - r -> handlePrivsResponse(username, request, r, listener), + r -> handlePrivsResponse(username, memoryCappedConfig, r, listener), listener::onFailure); client.execute(HasPrivilegesAction.INSTANCE, privRequest, privResponseListener); } else { - configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap( - indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())), + configProvider.put(memoryCappedConfig, threadPool.getThreadContext().getHeaders(), ActionListener.wrap( + indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(memoryCappedConfig)), listener::onFailure )); } } - private void handlePrivsResponse(String username, PutDataFrameAnalyticsAction.Request request, + private void handlePrivsResponse(String username, DataFrameAnalyticsConfig memoryCappedConfig, HasPrivilegesResponse response, ActionListener listener) throws IOException { if (response.isCompleteMatch()) { - configProvider.put(request.getConfig(), threadPool.getThreadContext().getHeaders(), ActionListener.wrap( - indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(request.getConfig())), + configProvider.put(memoryCappedConfig, threadPool.getThreadContext().getHeaders(), ActionListener.wrap( + indexResponse -> listener.onResponse(new PutDataFrameAnalyticsAction.Response(memoryCappedConfig)), listener::onFailure )); } else { @@ -119,7 +134,7 @@ private void handlePrivsResponse(String username, PutDataFrameAnalyticsAction.Re 
listener.onFailure(Exceptions.authorizationError("Cannot create data frame analytics [{}]" + " because user {} lacks permissions on the indices: {}", - request.getConfig().getId(), username, Strings.toString(builder))); + memoryCappedConfig.getId(), username, Strings.toString(builder))); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 2c2c110b76c12..0f8f634c4c18f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -12,8 +12,6 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.unit.ByteSizeUnit; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.env.Environment; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; @@ -40,7 +38,6 @@ public class AnalyticsProcessManager { private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class); private final Client client; - private final Environment environment; private final ThreadPool threadPool; private final AnalyticsProcessFactory processFactory; private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); @@ -48,7 +45,6 @@ public class AnalyticsProcessManager { public AnalyticsProcessManager(Client client, Environment environment, ThreadPool threadPool, AnalyticsProcessFactory analyticsProcessFactory) { this.client = Objects.requireNonNull(client); - this.environment = Objects.requireNonNull(environment); this.threadPool = Objects.requireNonNull(threadPool); this.processFactory = 
Objects.requireNonNull(analyticsProcessFactory); } @@ -155,7 +151,7 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig conf assert dataFrameAnalyses.size() == 1; AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, - new ByteSizeValue(1, ByteSizeUnit.GB), 1, dataFrameAnalyses.get(0)); + config.getModelMemoryLimit(), 1, dataFrameAnalyses.get(0)); return processConfig; } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index da6c27533c005..ed92f1423db19 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -405,3 +405,62 @@ catch: missing ml.delete_data_frame_analytics: id: "missing_config" + +--- +"Test max model memory limit": + - skip: + features: headers + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + cluster.put_settings: + body: + transient: + xpack.ml.max_model_memory_limit: "20mb" + - match: {transient.xpack.ml.max_model_memory_limit: "20mb"} + + # Explicit request higher than limit is an error + - do: + catch: /model_memory_limit \[8gb\] must be less than the value of the xpack.ml.max_model_memory_limit setting \[20mb\]/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection-with-query" + body: > + { + "source": "index-source", + "dest": "index-dest", + "analyses": [{"outlier_detection":{}}], + "query": {"term" : { "user" : "Kimchy" }}, + "model_memory_limit": "8gb", + "analyses_fields": [ "obj1.*", "obj2.*" ] + } + + # Request using default higher than limit gets silently capped + - do: + ml.put_data_frame_analytics: + id: "simple-outlier-detection-with-query" + body: > + { + "source": "index-source", + "dest": "index-dest", + "analyses": [{"outlier_detection":{}}], + "query": {"term" : { "user" : "Kimchy" }}, + "analyses_fields": [ "obj1.*", "obj2.*" ] + } + - match: { id: "simple-outlier-detection-with-query" } + - match: { source: "index-source" } + - match: { dest: "index-dest" } + - match: { analyses: [{"outlier_detection":{}}] } + - match: { query: {"term" : { "user" : "Kimchy"} } } + - match: { analyses_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } + - match: { model_memory_limit: "20mb" } + + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. 
the test setup superuser + cluster.put_settings: + body: + transient: + xpack.ml.max_model_memory_limit: null + - match: {transient: {}} + From dad41c87b98a94d31a849a8fc844483094010f70 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Wed, 6 Mar 2019 14:51:31 +0000 Subject: [PATCH 25/67] Fix build failure after merging master (#39743) --- .../xpack/ml/dataframe/process/AnalyticsProcessManager.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 0f8f634c4c18f..5f6c7d63942ae 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -55,7 +55,7 @@ public void runJob(long taskAllocationId, DataFrameAnalyticsConfig config, DataF DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); AnalyticsProcess process = createProcess(config.getId(), createProcessConfig(config, dataExtractor)); processContextByAllocation.putIfAbsent(taskAllocationId, new ProcessContext()); - ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); + ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(processContextByAllocation.get(taskAllocationId), @@ -136,7 +136,7 @@ private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsPr private AnalyticsProcess createProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig) { // TODO We should rename the 
thread pool to reflect its more general use now, e.g. JOB_PROCESS_THREAD_POOL_NAME - ExecutorService executorService = threadPool.executor(MachineLearning.AUTODETECT_THREAD_POOL_NAME); + ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, analyticsProcessConfig, executorService); if (process.isProcessAlive() == false) { throw ExceptionsHelper.serverError("Failed to start data frame analytics process"); From 142edfe2052c7b962fbdaa3b9ed29a427ee36b06 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Fri, 15 Mar 2019 15:38:56 +0000 Subject: [PATCH 26/67] [FEATURE] Fix up feature branch after merging master (#40094) --- .../action/AbstractGetResourcesRequest.java | 75 ---------- .../action/AbstractGetResourcesResponse.java | 84 ----------- .../action/GetDataFrameAnalyticsAction.java | 9 +- .../GetDataFrameAnalyticsStatsAction.java | 4 +- .../core/ml/datafeed/DatafeedConfig.java | 3 +- .../core/ml/datafeed/DatafeedUpdate.java | 5 +- .../dataframe/DataFrameAnalyticsConfig.java | 128 ++++++++--------- .../xpack/core/ml/job/messages/Messages.java | 2 +- .../ml/{datafeed => utils}/QueryProvider.java | 25 ++-- ...DataFrameAnalyticsActionResponseTests.java | 19 ++- .../GetDataFrameAnalyticsRequestTests.java | 2 +- ...rameAnalyticsStatsActionResponseTests.java | 2 +- ...etDataFrameAnalyticsStatsRequestTests.java | 2 +- ...DataFrameAnalyticsActionResponseTests.java | 18 +++ .../core/ml/datafeed/DatafeedConfigTests.java | 3 +- .../core/ml/datafeed/DatafeedUpdateTests.java | 3 +- .../DataFrameAnalyticsConfigTests.java | 42 ++---- .../QueryProviderTests.java | 11 +- .../AbstractTransportGetResourcesAction.java | 136 ------------------ .../TransportGetDataFrameAnalyticsAction.java | 20 ++- ...sportGetDataFrameAnalyticsStatsAction.java | 4 +- ...ransportStartDataFrameAnalyticsAction.java | 3 +- .../RestGetDataFrameAnalyticsAction.java | 2 +- 
.../RestGetDataFrameAnalyticsStatsAction.java | 2 +- 24 files changed, 171 insertions(+), 433 deletions(-) delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/{datafeed => utils}/QueryProvider.java (86%) rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/{datafeed => utils}/QueryProviderTests.java (95%) delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java deleted file mode 100644 index 7d287557f5e93..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesRequest.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.core.ml.action; - -import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.core.ml.action.util.PageParams; - -import java.io.IOException; -import java.util.Objects; - -public abstract class AbstractGetResourcesRequest extends ActionRequest { - - private String resourceId; - private PageParams pageParams = PageParams.defaultParams(); - - public final void setResourceId(String resourceId) { - this.resourceId = resourceId; - } - - public final String getResourceId() { - return resourceId; - } - - public final void setPageParams(PageParams pageParams) { - this.pageParams = pageParams; - } - - public final PageParams getPageParams() { - return pageParams; - } - - @Override - public ActionRequestValidationException validate() { - return null; - } - - @Override - public void readFrom(StreamInput in) throws IOException { - super.readFrom(in); - resourceId = in.readOptionalString(); - pageParams = in.readOptionalWriteable(PageParams::new); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeOptionalString(resourceId); - out.writeOptionalWriteable(pageParams); - } - - @Override - public int hashCode() { - return Objects.hash(resourceId, pageParams); - } - - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - if (obj instanceof AbstractGetResourcesRequest == false) { - return false; - } - AbstractGetResourcesRequest other = (AbstractGetResourcesRequest) obj; - return Objects.equals(resourceId, other.resourceId); - } - - public abstract String getResourceIdField(); -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java deleted file mode 100644 index 7f7686ff230ba..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/AbstractGetResourcesResponse.java +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.action; - -import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.StatusToXContentObject; -import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.ml.action.util.QueryPage; - -import java.io.IOException; -import java.util.Objects; - -public abstract class AbstractGetResourcesResponse extends ActionResponse - implements StatusToXContentObject { - - private QueryPage resources; - - protected AbstractGetResourcesResponse() {} - - protected AbstractGetResourcesResponse(QueryPage resources) { - this.resources = Objects.requireNonNull(resources); - } - - public QueryPage getResources() { - return resources; - } - - @Override - public void readFrom(StreamInput in) throws IOException { - super.readFrom(in); - resources = new QueryPage<>(in, getReader()); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - resources.writeTo(out); - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - - @Override - public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - resources.doXContentBody(builder, params); - builder.endObject(); - return builder; - } - - @Override - public int hashCode() { - return Objects.hash(resources); - } - - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - if (obj instanceof AbstractGetResourcesResponse == false) { - return false; - } - AbstractGetResourcesResponse other = (AbstractGetResourcesResponse) obj; - return Objects.equals(resources, other.resources); - } - - @Override - public final String toString() { - return Strings.toString(this); - } - protected abstract Reader getReader(); -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java index aeee3657604dd..b689da03ee4f0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java @@ -10,7 +10,9 @@ import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import java.io.IOException; @@ -32,10 +34,13 @@ public Response newResponse() { public static class Request extends AbstractGetResourcesRequest { - public Request() {} + public Request() { + setAllowNoResources(true); + } public Request(String id) { setResourceId(id); + 
setAllowNoResources(true); } public Request(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index 3bf5ac10a9a0f..5b19da2766129 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -22,8 +22,8 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.tasks.Task; -import org.elasticsearch.xpack.core.ml.action.util.PageParams; -import org.elasticsearch.xpack.core.ml.action.util.QueryPage; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java index 3cd071f61aaee..7187f16e173b0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java @@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import 
org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import org.elasticsearch.xpack.core.ml.utils.time.TimeUtils; @@ -123,7 +124,7 @@ private static ObjectParser createParser(boolean ignoreUnknownFie parser.declareString((builder, val) -> builder.setFrequency(TimeValue.parseTimeValue(val, FREQUENCY.getPreferredName())), FREQUENCY); parser.declareObject(Builder::setQueryProvider, - (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields), + (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT), QUERY); parser.declareObject(Builder::setAggregationsSafe, (p, c) -> AggProvider.fromXContent(p, ignoreUnknownFields), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java index 78b4e4ec7c2d3..e10ec20714362 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java @@ -22,7 +22,9 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; @@ -53,7 +55,8 @@ public class DatafeedUpdate implements Writeable, ToXContentObject { TimeValue.parseTimeValue(val, DatafeedConfig.QUERY_DELAY.getPreferredName())), DatafeedConfig.QUERY_DELAY); PARSER.declareString((builder, val) -> builder.setFrequency( TimeValue.parseTimeValue(val, DatafeedConfig.FREQUENCY.getPreferredName())), DatafeedConfig.FREQUENCY); - 
PARSER.declareObject(Builder::setQuery, (p, c) -> QueryProvider.fromXContent(p, false), DatafeedConfig.QUERY); + PARSER.declareObject(Builder::setQuery, (p, c) -> QueryProvider.fromXContent(p, false, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT), + DatafeedConfig.QUERY); PARSER.declareObject(Builder::setAggregationsSafe, (p, c) -> AggProvider.fromXContent(p, false), DatafeedConfig.AGGREGATIONS); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index ef58fd9c3fde7..dea7689fdc2bc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -7,23 +7,21 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.TriFunction; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.CachedSupplier; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentParseException; -import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import 
org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; @@ -33,7 +31,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -49,24 +46,6 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { public static final ByteSizeValue DEFAULT_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.GB); public static final ByteSizeValue MIN_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.MB); - private static final XContentObjectTransformer QUERY_TRANSFORMER = XContentObjectTransformer.queryBuilderTransformer(); - static final TriFunction, String, List, QueryBuilder> lazyQueryParser = - (objectMap, id, warnings) -> { - try { - return QUERY_TRANSFORMER.fromMap(objectMap, warnings); - } catch (IOException | XContentParseException exception) { - // Certain thrown exceptions wrap up the real Illegal argument making it hard to determine cause for the user - if (exception.getCause() instanceof IllegalArgumentException) { - throw ExceptionsHelper.badRequestException( - Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception.getCause()); - } else { - throw ExceptionsHelper.badRequestException( - Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception); - } - } - }; - - public static final ParseField ID = new ParseField("id"); public static final ParseField SOURCE = new ParseField("source"); public static final ParseField DEST = new ParseField("dest"); @@ -88,7 +67,9 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel parser.declareString(Builder::setSource, SOURCE); 
parser.declareString(Builder::setDest, DEST); parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES); - parser.declareObject((builder, query) -> builder.setQuery(query, ignoreUnknownFields), (p, c) -> p.mapOrdered(), QUERY); + parser.declareObject(Builder::setQueryProvider, + (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT), + QUERY); parser.declareField(Builder::setAnalysesFields, (p, c) -> FetchSourceContext.fromXContent(p), ANALYSES_FIELDS, @@ -107,8 +88,7 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final String source; private final String dest; private final List analyses; - private final Map query; - private final CachedSupplier querySupplier; + private final QueryProvider queryProvider; private final FetchSourceContext analysesFields; /** * This may be null up to the point of persistence, as the relationship with xpack.ml.max_model_memory_limit @@ -122,7 +102,7 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final Map headers; public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses, - Map query, Map headers, ByteSizeValue modelMemoryLimit, + QueryProvider queryProvider, Map headers, ByteSizeValue modelMemoryLimit, FetchSourceContext analysesFields) { this.id = ExceptionsHelper.requireNonNull(id, ID); this.source = ExceptionsHelper.requireNonNull(source, SOURCE); @@ -135,8 +115,7 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List 1) { throw new UnsupportedOperationException("Does not yet support multiple analyses"); } - this.query = Collections.unmodifiableMap(query); - this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); + this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY); this.analysesFields = analysesFields; this.modelMemoryLimit = modelMemoryLimit; this.headers = 
Collections.unmodifiableMap(headers); @@ -147,8 +126,7 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException { source = in.readString(); dest = in.readString(); analyses = in.readList(DataFrameAnalysisConfig::new); - this.query = in.readMap(); - this.querySupplier = new CachedSupplier<>(() -> lazyQueryParser.apply(query, id, new ArrayList<>())); + this.queryProvider = QueryProvider.fromStream(in); this.analysesFields = in.readOptionalWriteable(FetchSourceContext::new); this.modelMemoryLimit = in.readOptionalWriteable(ByteSizeValue::new); this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); @@ -170,34 +148,56 @@ public List getAnalyses() { return analyses; } - @Nullable - public Map getQuery() { - return query; - } - - @Nullable + /** + * Get the fully parsed query from the semi-parsed stored {@code Map} + * + * @return Fully parsed query + */ public QueryBuilder getParsedQuery() { - return querySupplier.get(); + Exception exception = queryProvider.getParsingException(); + if (exception != null) { + if (exception instanceof RuntimeException) { + throw (RuntimeException) exception; + } else { + throw new ElasticsearchException(queryProvider.getParsingException()); + } + } + return queryProvider.getParsedQuery(); } - public FetchSourceContext getAnalysesFields() { - return analysesFields; + Exception getQueryParsingException() { + return queryProvider == null ? 
null : queryProvider.getParsingException(); } /** - * Calls the lazy parser and returns any gathered deprecations + * Calls the parser and returns any gathered deprecations + * + * @param namedXContentRegistry XContent registry to transform the lazily parsed query * @return The deprecations from parsing the query */ - List getQueryDeprecations() { - return getQueryDeprecations(lazyQueryParser); - } - - List getQueryDeprecations(TriFunction, String, List, QueryBuilder> parser) { + public List getQueryDeprecations(NamedXContentRegistry namedXContentRegistry) { List deprecations = new ArrayList<>(); - parser.apply(query, id, deprecations); + try { + XContentObjectTransformer.queryBuilderTransformer(namedXContentRegistry).fromMap(queryProvider.getQuery(), + deprecations); + } catch (Exception exception) { + // Certain thrown exceptions wrap up the real Illegal argument making it hard to determine cause for the user + if (exception.getCause() instanceof IllegalArgumentException) { + exception = (Exception) exception.getCause(); + } + throw ExceptionsHelper.badRequestException(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, exception); + } return deprecations; } + public Map getQuery() { + return queryProvider.getQuery(); + } + + public FetchSourceContext getAnalysesFields() { + return analysesFields; + } + public ByteSizeValue getModelMemoryLimit() { return modelMemoryLimit != null ? 
modelMemoryLimit : DEFAULT_MODEL_MEMORY_LIMIT; } @@ -216,7 +216,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(ToXContentParams.INCLUDE_TYPE, false)) { builder.field(CONFIG_TYPE.getPreferredName(), TYPE); } - builder.field(QUERY.getPreferredName(), query); + builder.field(QUERY.getPreferredName(), queryProvider.getQuery()); if (analysesFields != null) { builder.field(ANALYSES_FIELDS.getPreferredName(), analysesFields); } @@ -234,7 +234,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(source); out.writeString(dest); out.writeList(analyses); - out.writeMap(query); + queryProvider.writeTo(out); out.writeOptionalWriteable(analysesFields); out.writeOptionalWriteable(modelMemoryLimit); out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); @@ -250,7 +250,7 @@ public boolean equals(Object o) { && Objects.equals(source, other.source) && Objects.equals(dest, other.dest) && Objects.equals(analyses, other.analyses) - && Objects.equals(query, other.query) + && Objects.equals(queryProvider, other.queryProvider) && Objects.equals(headers, other.headers) && Objects.equals(getModelMemoryLimit(), other.getModelMemoryLimit()) && Objects.equals(analysesFields, other.analysesFields); @@ -258,7 +258,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(id, source, dest, analyses, query, headers, getModelMemoryLimit(), analysesFields); + return Objects.hash(id, source, dest, analyses, queryProvider, headers, getModelMemoryLimit(), analysesFields); } public static String documentId(String id) { @@ -271,7 +271,7 @@ public static class Builder { private String source; private String dest; private List analyses; - private Map query = Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()); + private QueryProvider queryProvider = QueryProvider.defaultQuery(); private FetchSourceContext analysesFields; private 
ByteSizeValue modelMemoryLimit; private ByteSizeValue maxModelMemoryLimit; @@ -296,7 +296,7 @@ public Builder(DataFrameAnalyticsConfig config, ByteSizeValue maxModelMemoryLimi this.source = config.source; this.dest = config.dest; this.analyses = new ArrayList<>(config.analyses); - this.query = new LinkedHashMap<>(config.query); + this.queryProvider = new QueryProvider(config.queryProvider); this.headers = new HashMap<>(config.headers); this.modelMemoryLimit = config.modelMemoryLimit; this.maxModelMemoryLimit = maxModelMemoryLimit; @@ -329,22 +329,8 @@ public Builder setAnalyses(List analyses) { return this; } - public Builder setQuery(Map query) { - return setQuery(query, true); - } - - public Builder setQuery(Map query, boolean lenient) { - this.query = ExceptionsHelper.requireNonNull(query, QUERY.getPreferredName()); - try { - QUERY_TRANSFORMER.fromMap(query); - } catch (Exception exception) { - if (lenient) { - logger.warn(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception); - } else { - throw ExceptionsHelper.badRequestException( - Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, id), exception); - } - } + public Builder setQueryProvider(QueryProvider queryProvider) { + this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY.getPreferredName()); return this; } @@ -385,7 +371,7 @@ private void applyMaxModelMemoryLimit() { public DataFrameAnalyticsConfig build() { applyMaxModelMemoryLimit(); - return new DataFrameAnalyticsConfig(id, source, dest, analyses, query, headers, modelMemoryLimit, analysesFields); + return new DataFrameAnalyticsConfig(id, source, dest, analyses, queryProvider, headers, modelMemoryLimit, analysesFields); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 6eb0eb9616410..1f9f70f803d10 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -50,7 +50,7 @@ public final class Messages { "Datafeed frequency [{0}] must be a multiple of the aggregation interval [{1}]"; public static final String DATAFEED_ID_ALREADY_TAKEN = "A datafeed with id [{0}] already exists"; - public static final String DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT = "Data Frame Analytics config [{0}] query is not parsable"; + public static final String DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT = "Data Frame Analytics config query is not parsable"; public static final String DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER = "No compatible fields could be detected in index [{0}] with name [{1}]"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java similarity index 86% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java index ff6d2f595af81..fda7e87cad983 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.core.ml.datafeed; +package org.elasticsearch.xpack.core.ml.utils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -19,9 +19,6 @@ import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; import java.util.Collections; @@ -29,22 +26,22 @@ import java.util.Map; import java.util.Objects; -class QueryProvider implements Writeable, ToXContentObject { +public class QueryProvider implements Writeable, ToXContentObject { - private static final Logger logger = LogManager.getLogger(AggProvider.class); + private static final Logger logger = LogManager.getLogger(QueryProvider.class); private Exception parsingException; private QueryBuilder parsedQuery; private Map query; - static QueryProvider defaultQuery() { + public static QueryProvider defaultQuery() { return new QueryProvider( Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()), QueryBuilders.matchAllQuery(), null); } - static QueryProvider fromXContent(XContentParser parser, boolean lenient) throws IOException { + public static QueryProvider fromXContent(XContentParser parser, boolean lenient, String failureMessage) throws IOException { Map query = parser.mapOrdered(); QueryBuilder parsedQuery = null; Exception exception = null; @@ -56,15 +53,15 @@ static QueryProvider fromXContent(XContentParser parser, boolean lenient) throws } exception = ex; if (lenient) { - logger.warn(Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT, ex); + logger.warn(failureMessage, ex); } else { - throw ExceptionsHelper.badRequestException(Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT, ex); + throw 
ExceptionsHelper.badRequestException(failureMessage, ex); } } return new QueryProvider(query, parsedQuery, exception); } - static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOException { + public static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOException { return parsedQuery == null ? null : new QueryProvider( @@ -73,7 +70,7 @@ static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOExceptio null); } - static QueryProvider fromStream(StreamInput in) throws IOException { + public static QueryProvider fromStream(StreamInput in) throws IOException { if (in.getVersion().onOrAfter(Version.V_6_7_0)) { // Has our bug fix for query/agg providers return new QueryProvider(in.readMap(), in.readOptionalNamedWriteable(QueryBuilder.class), in.readException()); } else if (in.getVersion().onOrAfter(Version.V_6_6_0)) { // Has the bug, but supports lazy objects @@ -89,7 +86,7 @@ static QueryProvider fromStream(StreamInput in) throws IOException { this.parsingException = parsingException; } - QueryProvider(QueryProvider other) { + public QueryProvider(QueryProvider other) { this(other.query, other.parsedQuery, other.parsingException); } @@ -107,6 +104,8 @@ public void writeTo(StreamOutput out) throws IOException { if (parsingException != null) { // Do we have a parsing error? 
Throw it if (parsingException instanceof IOException) { throw (IOException) parsingException; + } else if (parsingException instanceof RuntimeException) { + throw (RuntimeException) parsingException; } else { throw new ElasticsearchException(parsingException); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java index e3b7262095abb..0c2a90195d238 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -5,17 +5,34 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractStreamableTestCase; +import org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Response; -import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import java.util.ArrayList; +import java.util.Collections; import java.util.List; public class GetDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry 
xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + @Override protected Response createTestInstance() { int listSize = randomInt(10); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java index 48381526b8c4c..438474076c910 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java @@ -7,8 +7,8 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Request; -import org.elasticsearch.xpack.core.ml.action.util.PageParams; public class GetDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java index ed9599b05152a..e01618520f5a8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -7,8 +7,8 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.QueryPage; import 
org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response; -import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java index 8db7d8db7877c..918d04873ef2c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java @@ -7,8 +7,8 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Request; -import org.elasticsearch.xpack.core.ml.action.util.PageParams; public class GetDataFrameAnalyticsStatsRequestTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java index 011044fb96eef..7830f874a4d6e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java @@ -5,12 +5,30 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; 
+import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractStreamableTestCase; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import java.util.Collections; + public class PutDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + @Override protected Response createTestInstance() { return new Response(DataFrameAnalyticsConfigTests.createRandom(DataFrameAnalyticsConfigTests.randomValidId())); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java index 71491c9227728..72765affaa576 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java @@ -46,6 +46,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig.Mode; import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; @@ -57,7 +58,7 @@ import java.util.List; import java.util.Map; -import static 
org.elasticsearch.xpack.core.ml.datafeed.QueryProviderTests.createRandomValidQueryProvider; +import static org.elasticsearch.xpack.core.ml.utils.QueryProviderTests.createRandomValidQueryProvider; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.DATAFEED_AGGREGATIONS_INTERVAL_MUST_BE_GREATER_THAN_ZERO; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java index 62436172d92a5..1d2d8dbfa0dab 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig.Mode; import org.elasticsearch.xpack.core.ml.job.config.JobTests; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; @@ -47,7 +48,7 @@ import java.util.List; import static org.elasticsearch.xpack.core.ml.datafeed.AggProviderTests.createRandomValidAggProvider; -import static org.elasticsearch.xpack.core.ml.datafeed.QueryProviderTests.createRandomValidQueryProvider; +import static org.elasticsearch.xpack.core.ml.utils.QueryProviderTests.createRandomValidQueryProvider; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 
498fe2a960b62..100f960200c39 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -23,15 +23,16 @@ import org.elasticsearch.common.xcontent.XContentParseException; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.MatchAllQueryBuilder; -import org.elasticsearch.index.query.TermQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.io.UncheckedIOException; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -41,11 +42,8 @@ import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasEntry; -import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.startsWith; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { @@ -90,9 +88,13 @@ public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) { .setSource(source) .setDest(dest); if (randomBoolean()) { - builder.setQuery( - Collections.singletonMap(TermQueryBuilder.NAME, - Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))), true); + try { + builder.setQueryProvider( + 
QueryProvider.fromParsedQuery(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)))); + } catch (IOException e) { + // Should never happen + throw new UncheckedIOException(e); + } } if (randomBoolean()) { builder.setAnalysesFields(new FetchSourceContext(true, @@ -130,7 +132,7 @@ public static String randomValidId() { public void testQueryConfigStoresUserInputOnly() throws IOException { try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(NamedXContentRegistry.EMPTY, + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, MODERN_QUERY_DATA_FRAME_ANALYTICS)) { @@ -139,7 +141,7 @@ public void testQueryConfigStoresUserInputOnly() throws IOException { } try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(NamedXContentRegistry.EMPTY, + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, MODERN_QUERY_DATA_FRAME_ANALYTICS)) { @@ -150,17 +152,17 @@ public void testQueryConfigStoresUserInputOnly() throws IOException { public void testPastQueryConfigParse() throws IOException { try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(NamedXContentRegistry.EMPTY, + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) { DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); - ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> config.getParsedQuery()); + ElasticsearchException e = expectThrows(ElasticsearchException.class, config::getParsedQuery); assertEquals("[match] query doesn't support multiple fields, found [query] and [type]", e.getMessage()); } try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) - .createParser(NamedXContentRegistry.EMPTY, + .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, 
ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) { @@ -197,20 +199,6 @@ public void testToXContentForInternalStorage() throws IOException { assertThat(parsedConfig.getHeaders().entrySet(), hasSize(0)); } - public void testGetQueryDeprecations() { - DataFrameAnalyticsConfig dataFrame = createTestInstance(); - String deprecationWarning = "Warning"; - List deprecations = dataFrame.getQueryDeprecations((map, id, deprecationlist) -> { - deprecationlist.add(deprecationWarning); - return new BoolQueryBuilder(); - }); - assertThat(deprecations, hasItem(deprecationWarning)); - - DataFrameAnalyticsConfig spiedConfig = spy(dataFrame); - spiedConfig.getQueryDeprecations(); - verify(spiedConfig).getQueryDeprecations(DataFrameAnalyticsConfig.lazyQueryParser); - } - public void testInvalidModelMemoryLimits() { DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProviderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/QueryProviderTests.java similarity index 95% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProviderTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/QueryProviderTests.java index fb6c2e280d975..a0c8174f8123b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProviderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/utils/QueryProviderTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. 
*/ -package org.elasticsearch.xpack.core.ml.datafeed; +package org.elasticsearch.xpack.core.ml.utils; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; @@ -26,7 +26,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; -import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import java.io.IOException; import java.util.Collections; @@ -68,7 +68,7 @@ protected Writeable.Reader instanceReader() { @Override protected QueryProvider doParseInstance(XContentParser parser) throws IOException { - return QueryProvider.fromXContent(parser, false); + return QueryProvider.fromXContent(parser, false, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT); } public static QueryProvider createRandomValidQueryProvider() { @@ -91,7 +91,7 @@ public void testEmptyQueryMap() throws IOException { XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry(), DeprecationHandler.THROW_UNSUPPORTED_OPERATION, "{}"); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> QueryProvider.fromXContent(parser, false)); + () -> QueryProvider.fromXContent(parser, false, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT)); assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST)); assertThat(e.getMessage(), equalTo("Datafeed query is not parsable")); } @@ -152,8 +152,7 @@ public void testSerializationBetweenEagerVersion() throws IOException { new ElasticsearchException("bad parsing")); output.setVersion(Version.V_6_0_0); ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> queryProviderWithEx.writeTo(output)); - assertNotNull(ex.getCause()); - assertThat(ex.getCause().getMessage(), equalTo("bad parsing")); + assertThat(ex.getMessage(), equalTo("bad parsing")); } try (BytesStreamOutput output = new 
BytesStreamOutput()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java deleted file mode 100644 index 56fa331230f03..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/AbstractTransportGetResourcesAction.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.ml.action; - -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.client.Client; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.regex.Regex; -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.XContentFactory; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.index.query.BoolQueryBuilder; -import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.search.SearchHit; -import 
org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.ml.action.AbstractGetResourcesRequest; -import org.elasticsearch.xpack.core.ml.action.AbstractGetResourcesResponse; -import org.elasticsearch.xpack.core.ml.action.util.QueryPage; -import org.elasticsearch.xpack.ml.utils.MlIndicesUtils; - -import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.function.Supplier; - -import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; - -public abstract class AbstractTransportGetResourcesAction> - extends HandledTransportAction { - - private static final String ALL = "_all"; - - private Client client; - - protected AbstractTransportGetResourcesAction(String actionName, TransportService transportService, ActionFilters actionFilters, - Supplier request, Client client) { - super(actionName, transportService, actionFilters, request); - this.client = Objects.requireNonNull(client); - } - - protected void searchResources(AbstractGetResourcesRequest request, ActionListener> listener) { - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() - .sort(request.getResourceIdField()) - .from(request.getPageParams().getFrom()) - .size(request.getPageParams().getSize()) - .query(buildQuery(request)); - - SearchRequest searchRequest = new SearchRequest(getIndices()) - .indicesOptions(MlIndicesUtils.addIgnoreUnavailable(SearchRequest.DEFAULT_INDICES_OPTIONS)) - .source(sourceBuilder); - - executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, searchRequest, new ActionListener() { - @Override - public void onResponse(SearchResponse response) { - List docs = new ArrayList<>(); - for (SearchHit hit : response.getHits().getHits()) { - BytesReference docSource = hit.getSourceRef(); - try 
(InputStream stream = docSource.streamInput(); - XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser( - NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, stream)) { - docs.add(parse(parser)); - } catch (IOException e) { - this.onFailure(e); - } - } - - if (docs.isEmpty() && isConcreteMatch(request.getResourceId())) { - listener.onFailure(notFoundException(request.getResourceId())); - } else { - listener.onResponse(new QueryPage<>(docs, docs.size(), getResultsField())); - } - } - - - @Override - public void onFailure(Exception e) { - listener.onFailure(e); - } - }, - client::search); - } - - private QueryBuilder buildQuery(AbstractGetResourcesRequest request) { - BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); - if (isMatchAll(request.getResourceId()) == false) { - boolQuery.filter(QueryBuilders.wildcardQuery(request.getResourceIdField(), request.getResourceId())); - } - QueryBuilder additionalQuery = additionalQuery(); - if (additionalQuery != null) { - boolQuery.filter(additionalQuery); - } - return boolQuery.hasClauses() ? 
boolQuery : QueryBuilders.matchAllQuery(); - } - - private static boolean isMatchAll(String resourceId) { - return Strings.isNullOrEmpty(resourceId) || ALL.equals(resourceId) || Regex.isMatchAllPattern(resourceId); - } - - private static boolean isConcreteMatch(String resourceId) { - return isMatchAll(resourceId) == false && Regex.isSimpleMatchPattern(resourceId) == false; - } - - @Nullable - protected QueryBuilder additionalQuery() { - return null; - } - - protected abstract ParseField getResultsField(); - - protected abstract String[] getIndices(); - - protected abstract Resource parse(XContentParser parser); - - protected abstract ResourceNotFoundException notFoundException(String resourceId); -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java index 02083b6c7d45e..fe1d3ac36f138 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsAction.java @@ -12,22 +12,28 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.action.AbstractTransportGetResourcesAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import 
org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + public class TransportGetDataFrameAnalyticsAction extends AbstractTransportGetResourcesAction { @Inject - public TransportGetDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters, Client client) { - super(GetDataFrameAnalyticsAction.NAME, transportService, actionFilters, GetDataFrameAnalyticsAction.Request::new, client); + public TransportGetDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters, Client client, + NamedXContentRegistry xContentRegistry) { + super(GetDataFrameAnalyticsAction.NAME, transportService, actionFilters, GetDataFrameAnalyticsAction.Request::new, client, + xContentRegistry); } @Override @@ -63,4 +69,14 @@ protected void doExecute(Task task, GetDataFrameAnalyticsAction.Request request, protected QueryBuilder additionalQuery() { return QueryBuilders.termQuery(DataFrameAnalyticsConfig.CONFIG_TYPE.getPreferredName(), DataFrameAnalyticsConfig.TYPE); } + + @Override + protected String executionOrigin() { + return ML_ORIGIN; + } + + @Override + protected String extractIdFromResource(DataFrameAnalyticsConfig config) { + return config.getId(); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index 5f096df93bf8d..ec6c9371ea405 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -27,11 +27,11 @@ import org.elasticsearch.tasks.TaskResult; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import 
org.elasticsearch.xpack.core.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response.Stats; -import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; @@ -68,7 +68,7 @@ public TransportGetDataFrameAnalyticsStatsAction(TransportService transportServi @Override protected GetDataFrameAnalyticsStatsAction.Response newResponse(GetDataFrameAnalyticsStatsAction.Request request, - List> tasks, + List> tasks, List taskFailures, List nodeFailures) { List stats = new ArrayList<>(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 8e9029c38c46f..e0cc712c631b5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -137,8 +137,7 @@ public void onFailure(Exception e) { // Validate config ActionListener configListener = ActionListener.wrap( - config -> - DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, config, validateListener), + config -> DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, config, validateListener), listener::onFailure ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java index 6f2c89fd09c53..938694065a7ea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java @@ -12,8 +12,8 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; -import org.elasticsearch.xpack.core.ml.action.util.PageParams; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.ml.MachineLearning; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java index a2d2b1ca48e27..8f1781ba75fc3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java @@ -12,8 +12,8 @@ import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; -import org.elasticsearch.xpack.core.ml.action.util.PageParams; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.ml.MachineLearning; From 2f1a60eb1ce12d61971f5a4da35c02872a616353 Mon Sep 17 00:00:00 2001 From: David Roberts 
Date: Wed, 27 Mar 2019 22:32:29 +0000 Subject: [PATCH 27/67] [FEATURE] Adjust ML memory tracker to include data frame analytics jobs (#40451) This change is a step towards assigning persistent tasks for data frame analytics jobs to nodes based on their memory usage and the memory usage of other ML jobs of all types. Included in this change: 1. ML memory tracker knows about both anomaly detector jobs and data frame analytics jobs 2. Starting a data frame analytics job refreshes the ML memory tracker 3. Deleting a data frame analytics job removes its entry from the ML memory tracker Deferred to a subsequent change: 1. Assigning anomaly detector jobs based on the location of previously assigned data frame analytics jobs as well as previously assigned anomaly detector jobs 2. Assigning data frame analytics jobs based on memory requirement --- .../xpack/ml/MachineLearning.java | 23 +-- ...ansportDeleteDataFrameAnalyticsAction.java | 20 ++- .../ml/action/TransportDeleteJobAction.java | 2 +- .../ml/action/TransportOpenJobAction.java | 7 +- ...ransportStartDataFrameAnalyticsAction.java | 14 +- .../DataFrameDataExtractorFactory.java | 4 +- .../DataFrameAnalyticsConfigProvider.java | 10 ++ .../process/AnalyticsProcessManager.java | 1 - .../xpack/ml/process/MlMemoryTracker.java | 167 ++++++++++++++---- .../ml/process/MlMemoryTrackerTests.java | 62 +++++-- 10 files changed, 228 insertions(+), 82 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 79c4550119f2c..e5880ceb408aa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -481,16 +481,6 @@ public Collection createComponents(Client client, ClusterService cluster DatafeedManager datafeedManager = new DatafeedManager(threadPool, client, clusterService, 
datafeedJobBuilder, System::currentTimeMillis, auditor, autodetectProcessManager); this.datafeedManager.set(datafeedManager); - MlMemoryTracker memoryTracker = new MlMemoryTracker(settings, clusterService, threadPool, jobManager, jobResultsProvider); - this.memoryTracker.set(memoryTracker); - MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager, - autodetectProcessManager, memoryTracker); - - // This object's constructor attaches to the license state, so there's no need to retain another reference to it - new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager); - - // run node startup tasks - autodetectProcessManager.onNodeStartup(); // Data frame analytics components AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, environment, threadPool, @@ -501,6 +491,19 @@ public Collection createComponents(Client client, ClusterService cluster dataFrameAnalyticsConfigProvider, analyticsProcessManager); this.dataFrameAnalyticsManager.set(dataFrameAnalyticsManager); + // Components shared by anomaly detection and data frame analytics + MlMemoryTracker memoryTracker = new MlMemoryTracker(settings, clusterService, threadPool, jobManager, jobResultsProvider, + dataFrameAnalyticsConfigProvider); + this.memoryTracker.set(memoryTracker); + MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager, + autodetectProcessManager, memoryTracker); + + // This object's constructor attaches to the license state, so there's no need to retain another reference to it + new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager); + + // run node startup tasks + autodetectProcessManager.onNodeStartup(); + return Arrays.asList( mlLifeCycleService, jobResultsProvider, diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java index c14f8bef92c0e..51510cfaf733c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -42,14 +43,17 @@ public class TransportDeleteDataFrameAnalyticsAction extends TransportMasterNodeAction { private final Client client; + private final MlMemoryTracker memoryTracker; @Inject public TransportDeleteDataFrameAnalyticsAction(TransportService transportService, ClusterService clusterService, - ThreadPool threadPool, ActionFilters actionFilters, - IndexNameExpressionResolver indexNameExpressionResolver, Client client) { + ThreadPool threadPool, ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, Client client, + MlMemoryTracker memoryTracker) { super(DeleteDataFrameAnalyticsAction.NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, DeleteDataFrameAnalyticsAction.Request::new); this.client = client; + this.memoryTracker = memoryTracker; } @Override @@ -65,21 +69,25 @@ protected AcknowledgedResponse newResponse() { @Override protected void masterOperation(DeleteDataFrameAnalyticsAction.Request request, ClusterState state, ActionListener listener) { + String 
id = request.getId(); PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); - DataFrameAnalyticsState taskState = MlTasks.getDataFrameAnalyticsState(request.getId(), tasks); + DataFrameAnalyticsState taskState = MlTasks.getDataFrameAnalyticsState(id, tasks); if (taskState != DataFrameAnalyticsState.STOPPED) { listener.onFailure(ExceptionsHelper.conflictStatusException("Cannot delete data frame analytics [{}] while its status is [{}]", - request.getId(), taskState)); + id, taskState)); return; } + // We clean up the memory tracker on delete because there is no stop; the task stops by itself + memoryTracker.removeDataFrameAnalyticsJob(id); + DeleteRequest deleteRequest = new DeleteRequest(AnomalyDetectorsIndex.configIndexName()); - deleteRequest.id(DataFrameAnalyticsConfig.documentId(request.getId())); + deleteRequest.id(DataFrameAnalyticsConfig.documentId(id)); deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); executeAsyncWithOrigin(client, ML_ORIGIN, DeleteAction.INSTANCE, deleteRequest, ActionListener.wrap( deleteResponse -> { if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { - listener.onFailure(ExceptionsHelper.missingDataFrameAnalytics(request.getId())); + listener.onFailure(ExceptionsHelper.missingDataFrameAnalytics(id)); return; } assert deleteResponse.getResult() == DocWriteResponse.Result.DELETED; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java index 3f8321fa4b1b8..e5837eb8251b3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java @@ -231,7 +231,7 @@ private void normalDeleteJob(ParentTaskAssigningClient parentTaskClient, DeleteJ String jobId = 
request.getJobId(); // We clean up the memory tracker on delete rather than close as close is not a master node action - memoryTracker.removeJob(jobId); + memoryTracker.removeAnomalyDetectorJob(jobId); // Step 4. When the job has been removed from the cluster state, return a response // ------- diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java index 162ed33657479..99df655114ca2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java @@ -206,7 +206,7 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j ++numberOfAllocatingJobs; } OpenJobAction.JobParams params = (OpenJobAction.JobParams) assignedTask.getParams(); - Long jobMemoryRequirement = memoryTracker.getJobMemoryRequirement(params.getJobId()); + Long jobMemoryRequirement = memoryTracker.getAnomalyDetectorJobMemoryRequirement(params.getJobId()); if (jobMemoryRequirement == null) { allocateByMemory = false; logger.debug("Falling back to allocating job [{}] by job counts because " + @@ -271,7 +271,7 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j if (allocateByMemory) { if (machineMemory > 0) { long maxMlMemory = machineMemory * maxMachineMemoryPercent / 100; - Long estimatedMemoryFootprint = memoryTracker.getJobMemoryRequirement(jobId); + Long estimatedMemoryFootprint = memoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId); if (estimatedMemoryFootprint != null) { long availableMemory = maxMlMemory - assignedJobMemory; if (estimatedMemoryFootprint > availableMemory) { @@ -450,7 +450,8 @@ public void onFailure(Exception e) { // Tell the job tracker to refresh the memory requirement for this job and all other jobs that have persistent tasks 
ActionListener getJobHandler = ActionListener.wrap( - response -> memoryTracker.refreshJobMemoryAndAllOthers(jobParams.getJobId(), memoryRequirementRefreshListener), + response -> memoryTracker.refreshAnomalyDetectorJobMemoryAndAllOthers(jobParams.getJobId(), + memoryRequirementRefreshListener), listener::onFailure ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index e0cc712c631b5..697dde824ddb3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -46,6 +46,7 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import java.util.Map; import java.util.Objects; @@ -65,19 +66,21 @@ public class TransportStartDataFrameAnalyticsAction private final Client client; private final PersistentTasksService persistentTasksService; private final DataFrameAnalyticsConfigProvider configProvider; + private final MlMemoryTracker memoryTracker; @Inject public TransportStartDataFrameAnalyticsAction(TransportService transportService, Client client, ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, XPackLicenseState licenseState, IndexNameExpressionResolver indexNameExpressionResolver, PersistentTasksService persistentTasksService, - DataFrameAnalyticsConfigProvider configProvider) { + DataFrameAnalyticsConfigProvider configProvider, MlMemoryTracker memoryTracker) { super(StartDataFrameAnalyticsAction.NAME, transportService, 
clusterService, threadPool, actionFilters, indexNameExpressionResolver, StartDataFrameAnalyticsAction.Request::new); this.licenseState = licenseState; this.client = client; this.persistentTasksService = persistentTasksService; this.configProvider = configProvider; + this.memoryTracker = memoryTracker; } @Override @@ -129,12 +132,19 @@ public void onFailure(Exception e) { }; // Start persistent task - ActionListener validateListener = ActionListener.wrap( + ActionListener memoryRequirementRefreshListener = ActionListener.wrap( validated -> persistentTasksService.sendStartRequest(MlTasks.dataFrameAnalyticsTaskId(request.getId()), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, taskParams, waitForAnalyticsToStart), listener::onFailure ); + // Tell the job tracker to refresh the memory requirement for this job and all other jobs that have persistent tasks + ActionListener validateListener = ActionListener.wrap( + config -> memoryTracker.addDataFrameAnalyticsJobMemoryAndRefreshAllOthers( + request.getId(), config.getModelMemoryLimit().getBytes(), memoryRequirementRefreshListener), + listener::onFailure + ); + // Validate config ActionListener configListener = ActionListener.wrap( config -> DataFrameDataExtractorFactory.validateConfigAndSourceIndex(client, config, validateListener), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index d8c7b44d496ad..ab8fb96f74e13 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -119,13 +119,13 @@ public static void create(Client client, */ public static void validateConfigAndSourceIndex(Client client, DataFrameAnalyticsConfig config, - ActionListener listener) 
{ + ActionListener listener) { Set resultFields = resolveResultsFields(config); validateIndexAndExtractFields(client, config.getHeaders(), config.getSource(), config.getAnalysesFields(), resultFields, ActionListener.wrap( fields -> { config.getParsedQuery(); // validate query is acceptable - listener.onResponse(true); + listener.onResponse(config); }, listener::onFailure )); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java index 5ae2358bcb2fb..4aefd57fb0eae 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java @@ -104,4 +104,14 @@ public void get(String id, ActionListener listener) { listener::onFailure )); } + + /** + * @param ids a comma separated list of single IDs and/or wildcards + */ + public void getMultiple(String ids, ActionListener> listener) { + GetDataFrameAnalyticsAction.Request request = new GetDataFrameAnalyticsAction.Request(); + request.setResourceId(ids); + executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, request, ActionListener.wrap( + response -> listener.onResponse(response.getResources().results()), listener::onFailure)); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 5f6c7d63942ae..5849b597de688 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -135,7 
+135,6 @@ private void writeHeaderRecord(DataFrameDataExtractor dataExtractor, AnalyticsPr } private AnalyticsProcess createProcess(String jobId, AnalyticsProcessConfig analyticsProcessConfig) { - // TODO We should rename the thread pool to reflect its more general use now, e.g. JOB_PROCESS_THREAD_POOL_NAME ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); AnalyticsProcess process = processFactory.createAnalyticsProcess(jobId, analyticsProcessConfig, executorService); if (process.isProcessAlive() == false) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java index 50d2515046a22..29116b320a885 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java @@ -20,9 +20,12 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; @@ -38,35 +41,41 @@ /** * This class keeps track of the memory requirement of ML jobs. * It only functions on the master node - for this reason it should only be used by master node actions. - * The memory requirement for ML jobs can be updated in 3 ways: - * 1. 
For all open ML jobs (via {@link #asyncRefresh}) - * 2. For all open ML jobs, plus one named ML job that is not open (via {@link #refreshJobMemoryAndAllOthers}) - * 3. For one named ML job (via {@link #refreshJobMemory}) - * In cases 2 and 3 a listener informs the caller when the requested updates are complete. + * The memory requirement for ML jobs can be updated in 4 ways: + * 1. For all open ML data frame analytics jobs and anomaly detector jobs (via {@link #asyncRefresh}) + * 2. For all open/started ML jobs, plus one named ML anomaly detector job that is not open + * (via {@link #refreshAnomalyDetectorJobMemoryAndAllOthers}) + * 3. For all open/started ML jobs, plus one named ML data frame analytics job that is not started + * (via {@link #addDataFrameAnalyticsJobMemoryAndRefreshAllOthers}) + * 4. For one named ML anomaly detector job (via {@link #refreshAnomalyDetectorJobMemory}) + * In cases 2, 3 and 4 a listener informs the caller when the requested updates are complete. */ public class MlMemoryTracker implements LocalNodeMasterListener { private static final Duration RECENT_UPDATE_THRESHOLD = Duration.ofMinutes(1); private final Logger logger = LogManager.getLogger(MlMemoryTracker.class); - private final ConcurrentHashMap memoryRequirementByJob = new ConcurrentHashMap<>(); + private final ConcurrentHashMap memoryRequirementByAnomalyDetectorJob = new ConcurrentHashMap<>(); + private final ConcurrentHashMap memoryRequirementByDataFrameAnalyticsJob = new ConcurrentHashMap<>(); private final List> fullRefreshCompletionListeners = new ArrayList<>(); private final ThreadPool threadPool; private final ClusterService clusterService; private final JobManager jobManager; private final JobResultsProvider jobResultsProvider; + private final DataFrameAnalyticsConfigProvider configProvider; private final Phaser stopPhaser; private volatile boolean isMaster; private volatile Instant lastUpdateTime; private volatile Duration reassignmentRecheckInterval; public 
MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadPool threadPool, JobManager jobManager, - JobResultsProvider jobResultsProvider) { + JobResultsProvider jobResultsProvider, DataFrameAnalyticsConfigProvider configProvider) { this.threadPool = threadPool; this.clusterService = clusterService; this.jobManager = jobManager; this.jobResultsProvider = jobResultsProvider; + this.configProvider = configProvider; this.stopPhaser = new Phaser(1); setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings)); clusterService.addLocalNodeMasterListener(this); @@ -88,7 +97,8 @@ public void onMaster() { public void offMaster() { isMaster = false; logger.trace("ML memory tracker off master"); - memoryRequirementByJob.clear(); + memoryRequirementByAnomalyDetectorJob.clear(); + memoryRequirementByDataFrameAnalyticsJob.clear(); lastUpdateTime = null; } @@ -125,19 +135,19 @@ public boolean isRecentlyRefreshed() { } /** - * Get the memory requirement for a job. + * Get the memory requirement for an anomaly detector job. * This method only works on the master node. * @param jobId The job ID. * @return The memory requirement of the job specified by {@code jobId}, * or null if it cannot be calculated. */ - public Long getJobMemoryRequirement(String jobId) { + public Long getAnomalyDetectorJobMemoryRequirement(String jobId) { if (isMaster == false) { return null; } - Long memoryRequirement = memoryRequirementByJob.get(jobId); + Long memoryRequirement = memoryRequirementByAnomalyDetectorJob.get(jobId); if (memoryRequirement != null) { return memoryRequirement; } @@ -146,16 +156,46 @@ public Long getJobMemoryRequirement(String jobId) { } /** - * Remove any memory requirement that is stored for the specified job. - * It doesn't matter if this method is called for a job that doesn't have - * a stored memory requirement. + * Get the memory requirement for a data frame analytics job. 
+ * This method only works on the master node. + * @param id The job ID. + * @return The memory requirement of the job specified by {@code id}, + * or null if it cannot be found. */ - public void removeJob(String jobId) { - memoryRequirementByJob.remove(jobId); + public Long getDataFrameAnalyticsJobMemoryRequirement(String id) { + + if (isMaster == false) { + return null; + } + + Long memoryRequirement = memoryRequirementByDataFrameAnalyticsJob.get(id); + if (memoryRequirement != null) { + return memoryRequirement; + } + + return null; } /** - * Uses a separate thread to refresh the memory requirement for every ML job that has + * Remove any memory requirement that is stored for the specified anomaly detector job. + * It doesn't matter if this method is called for a job that doesn't have a + * stored memory requirement. + */ + public void removeAnomalyDetectorJob(String jobId) { + memoryRequirementByAnomalyDetectorJob.remove(jobId); + } + + /** + * Remove any memory requirement that is stored for the specified data frame analytics + * job. It doesn't matter if this method is called for a job that doesn't have a + * stored memory requirement. + */ + public void removeDataFrameAnalyticsJob(String id) { + memoryRequirementByDataFrameAnalyticsJob.remove(id); + } + + /** + * Uses a separate thread to refresh the memory requirement for every ML anomaly detector job that has * a corresponding persistent task. This method only works on the master node. * @return true if the async refresh is scheduled, and false * if this is not possible for some reason. @@ -188,22 +228,44 @@ public boolean asyncRefresh() { * @param listener Receives the memory requirement of the job specified by {@code jobId}, * or null if it cannot be calculated. 
*/ - public void refreshJobMemoryAndAllOthers(String jobId, ActionListener listener) { + public void refreshAnomalyDetectorJobMemoryAndAllOthers(String jobId, ActionListener listener) { + + if (isMaster == false) { + listener.onResponse(null); + return; + } + + PersistentTasksCustomMetaData persistentTasks = clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + refresh(persistentTasks, + ActionListener.wrap(aVoid -> refreshAnomalyDetectorJobMemory(jobId, listener), listener::onFailure)); + } + + /** + * This refreshes the memory requirement for every ML job that has a corresponding + * persistent task and, in addition, adds the memory requirement of one data frame analytics + * job that doesn't have a persistent task. This method only works on the master node. + * @param id The job ID of the job whose memory requirement is to be added. + * @param mem The memory requirement (in bytes) of the job specified by {@code id}. + * @param listener Called when the refresh is complete or fails. + */ + public void addDataFrameAnalyticsJobMemoryAndRefreshAllOthers(String id, long mem, ActionListener listener) { if (isMaster == false) { listener.onResponse(null); return; } + memoryRequirementByDataFrameAnalyticsJob.put(id, mem); + PersistentTasksCustomMetaData persistentTasks = clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE); - refresh(persistentTasks, ActionListener.wrap(aVoid -> refreshJobMemory(jobId, listener), listener::onFailure)); + refresh(persistentTasks, listener); } /** * This refreshes the memory requirement for every ML job that has a corresponding persistent task. - * It does NOT remove entries for jobs that no longer have a persistent task, because that would - * lead to a race where a job was opened part way through the refresh. (Instead, entries are removed - * when jobs are deleted.) 
+ * It does NOT remove entries for jobs that no longer have a persistent task, because that would lead + * to a race where a job was opened part way through the refresh. (Instead, entries are removed when + * jobs are deleted.) */ void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener onCompletion) { @@ -230,37 +292,64 @@ void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener if (persistentTasks == null) { refreshComplete.onResponse(null); } else { - List> mlJobTasks = persistentTasks.tasks().stream() + List> mlDataFrameAnalyticsJobTasks = persistentTasks.tasks().stream() + .filter(task -> MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME.equals(task.getTaskName())).collect(Collectors.toList()); + ActionListener refreshDataFrameAnalyticsJobs = + ActionListener.wrap(aVoid -> refreshAllDataFrameAnalyticsJobTasks(mlDataFrameAnalyticsJobTasks, refreshComplete), + refreshComplete::onFailure); + + List> mlAnomalyDetectorJobTasks = persistentTasks.tasks().stream() .filter(task -> MlTasks.JOB_TASK_NAME.equals(task.getTaskName())).collect(Collectors.toList()); - iterateMlJobTasks(mlJobTasks.iterator(), refreshComplete); + iterateAnomalyDetectorJobTasks(mlAnomalyDetectorJobTasks.iterator(), refreshDataFrameAnalyticsJobs); } } - private void iterateMlJobTasks(Iterator> iterator, - ActionListener refreshComplete) { + private void iterateAnomalyDetectorJobTasks(Iterator> iterator, + ActionListener refreshComplete) { if (iterator.hasNext()) { OpenJobAction.JobParams jobParams = (OpenJobAction.JobParams) iterator.next().getParams(); - refreshJobMemory(jobParams.getJobId(), + refreshAnomalyDetectorJobMemory(jobParams.getJobId(), ActionListener.wrap( // Do the next iteration in a different thread, otherwise stack overflow // can occur if the searches happen to be on the local node, as the huge // chain of listeners are all called in the same thread if only one node // is involved - mem -> threadPool.executor(executorName()).execute(() -> 
iterateMlJobTasks(iterator, refreshComplete)), + mem -> threadPool.executor(executorName()).execute(() -> iterateAnomalyDetectorJobTasks(iterator, refreshComplete)), refreshComplete::onFailure)); } else { refreshComplete.onResponse(null); } } + private void refreshAllDataFrameAnalyticsJobTasks(List> mlDataFrameAnalyticsJobTasks, + ActionListener listener) { + if (mlDataFrameAnalyticsJobTasks.isEmpty()) { + listener.onResponse(null); + return; + } + + String startedJobIds = mlDataFrameAnalyticsJobTasks.stream() + .map(task -> ((StartDataFrameAnalyticsAction.TaskParams) task.getParams()).getId()).sorted().collect(Collectors.joining(",")); + + configProvider.getMultiple(startedJobIds, ActionListener.wrap( + analyticsConfigs -> { + for (DataFrameAnalyticsConfig analyticsConfig : analyticsConfigs) { + memoryRequirementByDataFrameAnalyticsJob.put(analyticsConfig.getId(), analyticsConfig.getModelMemoryLimit().getBytes()); + } + listener.onResponse(null); + }, + listener::onFailure + )); + } + /** - * Refresh the memory requirement for a single job. + * Refresh the memory requirement for a single anomaly detector job. * This method only works on the master node. * @param jobId The ID of the job to refresh the memory requirement for. * @param listener Receives the job's memory requirement, or null * if it cannot be calculated. 
*/ - public void refreshJobMemory(String jobId, ActionListener listener) { + public void refreshAnomalyDetectorJobMemory(String jobId, ActionListener listener) { if (isMaster == false) { listener.onResponse(null); return; @@ -288,25 +377,25 @@ public void refreshJobMemory(String jobId, ActionListener listener) { jobResultsProvider.getEstablishedMemoryUsage(jobId, null, null, establishedModelMemoryBytes -> { if (establishedModelMemoryBytes <= 0L) { - setJobMemoryToLimit(jobId, phaserListener); + setAnomalyDetectorJobMemoryToLimit(jobId, phaserListener); } else { Long memoryRequirementBytes = establishedModelMemoryBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes(); - memoryRequirementByJob.put(jobId, memoryRequirementBytes); + memoryRequirementByAnomalyDetectorJob.put(jobId, memoryRequirementBytes); phaserListener.onResponse(memoryRequirementBytes); } }, e -> { - logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); - setJobMemoryToLimit(jobId, phaserListener); + logger.error("[" + jobId + "] failed to calculate anomaly detector job established model memory requirement", e); + setAnomalyDetectorJobMemoryToLimit(jobId, phaserListener); } ); } catch (Exception e) { - logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); - setJobMemoryToLimit(jobId, phaserListener); + logger.error("[" + jobId + "] failed to calculate anomaly detector job established model memory requirement", e); + setAnomalyDetectorJobMemoryToLimit(jobId, phaserListener); } } - private void setJobMemoryToLimit(String jobId, ActionListener listener) { + private void setAnomalyDetectorJobMemoryToLimit(String jobId, ActionListener listener) { jobManager.getJob(jobId, ActionListener.wrap(job -> { Long memoryLimitMb = (job.getAnalysisLimits() != null) ? 
job.getAnalysisLimits().getModelMemoryLimit() : null; // Although recent versions of the code enforce a non-null model_memory_limit @@ -316,7 +405,7 @@ private void setJobMemoryToLimit(String jobId, ActionListener listener) { memoryLimitMb = AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB; } Long memoryRequirementBytes = ByteSizeUnit.MB.toBytes(memoryLimitMb) + Job.PROCESS_MEMORY_OVERHEAD.getBytes(); - memoryRequirementByJob.put(jobId, memoryRequirementBytes); + memoryRequirementByAnomalyDetectorJob.put(jobId, memoryRequirementBytes); listener.onResponse(memoryRequirementBytes); }, e -> { if (e instanceof ResourceNotFoundException) { @@ -325,7 +414,7 @@ private void setJobMemoryToLimit(String jobId, ActionListener listener) { } else { logger.error("[" + jobId + "] failed to get job during ML memory update", e); } - memoryRequirementByJob.remove(jobId); + memoryRequirementByAnomalyDetectorJob.remove(jobId); listener.onResponse(null); })); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java index 1dd2ba923ef00..426a9e0f83984 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java @@ -17,14 +17,18 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; import 
org.junit.Before; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicReference; @@ -45,6 +49,7 @@ public class MlMemoryTrackerTests extends ESTestCase { private JobManager jobManager; private JobResultsProvider jobResultsProvider; + private DataFrameAnalyticsConfigProvider configProvider; private MlMemoryTracker memoryTracker; @Before @@ -65,7 +70,8 @@ public void setup() { when(threadPool.executor(anyString())).thenReturn(executorService); jobManager = mock(JobManager.class); jobResultsProvider = mock(JobResultsProvider.class); - memoryTracker = new MlMemoryTracker(Settings.EMPTY, clusterService, threadPool, jobManager, jobResultsProvider); + configProvider = mock(DataFrameAnalyticsConfigProvider.class); + memoryTracker = new MlMemoryTracker(Settings.EMPTY, clusterService, threadPool, jobManager, jobResultsProvider, configProvider); } public void testRefreshAll() { @@ -77,14 +83,26 @@ public void testRefreshAll() { memoryTracker.offMaster(); } - int numMlJobTasks = randomIntBetween(2, 5); Map> tasks = new HashMap<>(); - for (int i = 1; i <= numMlJobTasks; ++i) { + + int numAnomalyDetectorJobTasks = randomIntBetween(2, 5); + for (int i = 1; i <= numAnomalyDetectorJobTasks; ++i) { String jobId = "job" + i; - PersistentTasksCustomMetaData.PersistentTask task = makeTestTask(jobId); + PersistentTasksCustomMetaData.PersistentTask task = makeTestAnomalyDetectorTask(jobId); + tasks.put(task.getId(), task); + } + + List allIds = new ArrayList<>(); + int numDataFrameAnalyticsTasks = randomIntBetween(2, 5); + for (int i = 1; i <= numDataFrameAnalyticsTasks; ++i) { + String id = "analytics" + i; + allIds.add(id); + PersistentTasksCustomMetaData.PersistentTask task = makeTestDataFrameAnalyticsTask(id); tasks.put(task.getId(), task); } - PersistentTasksCustomMetaData persistentTasks = new 
PersistentTasksCustomMetaData(numMlJobTasks, tasks); + + PersistentTasksCustomMetaData persistentTasks = + new PersistentTasksCustomMetaData(numAnomalyDetectorJobTasks + numDataFrameAnalyticsTasks, tasks); doAnswer(invocation -> { @SuppressWarnings("unchecked") @@ -96,16 +114,17 @@ public void testRefreshAll() { memoryTracker.refresh(persistentTasks, ActionListener.wrap(aVoid -> {}, ESTestCase::assertNull)); if (isMaster) { - for (int i = 1; i <= numMlJobTasks; ++i) { + for (int i = 1; i <= numAnomalyDetectorJobTasks; ++i) { String jobId = "job" + i; verify(jobResultsProvider, times(1)).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(), any()); } + verify(configProvider, times(1)).getMultiple(eq(String.join(",", allIds)), any(ActionListener.class)); } else { verify(jobResultsProvider, never()).getEstablishedMemoryUsage(anyString(), any(), any(), any(), any()); } } - public void testRefreshOne() { + public void testRefreshOneAnomalyDetectorJob() { boolean isMaster = randomBoolean(); if (isMaster) { @@ -137,26 +156,26 @@ public void testRefreshOne() { }).when(jobManager).getJob(eq(jobId), any(ActionListener.class)); AtomicReference refreshedMemoryRequirement = new AtomicReference<>(); - memoryTracker.refreshJobMemory(jobId, ActionListener.wrap(refreshedMemoryRequirement::set, ESTestCase::assertNull)); + memoryTracker.refreshAnomalyDetectorJobMemory(jobId, ActionListener.wrap(refreshedMemoryRequirement::set, ESTestCase::assertNull)); if (isMaster) { if (haveEstablishedModelMemory) { assertEquals(Long.valueOf(modelBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes()), - memoryTracker.getJobMemoryRequirement(jobId)); + memoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId)); } else { long expectedModelMemoryLimit = simulateVeryOldJob ? 
AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB : recentJobModelMemoryLimitMb; assertEquals(Long.valueOf(ByteSizeUnit.MB.toBytes(expectedModelMemoryLimit) + Job.PROCESS_MEMORY_OVERHEAD.getBytes()), - memoryTracker.getJobMemoryRequirement(jobId)); + memoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId)); } } else { - assertNull(memoryTracker.getJobMemoryRequirement(jobId)); + assertNull(memoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId)); } - assertEquals(memoryTracker.getJobMemoryRequirement(jobId), refreshedMemoryRequirement.get()); + assertEquals(memoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId), refreshedMemoryRequirement.get()); - memoryTracker.removeJob(jobId); - assertNull(memoryTracker.getJobMemoryRequirement(jobId)); + memoryTracker.removeAnomalyDetectorJob(jobId); + assertNull(memoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId)); } public void testStop() { @@ -165,15 +184,22 @@ public void testStop() { memoryTracker.stop(); AtomicReference exception = new AtomicReference<>(); - memoryTracker.refreshJobMemory("job", ActionListener.wrap(ESTestCase::assertNull, exception::set)); + memoryTracker.refreshAnomalyDetectorJobMemory("job", ActionListener.wrap(ESTestCase::assertNull, exception::set)); assertNotNull(exception.get()); assertThat(exception.get(), instanceOf(EsRejectedExecutionException.class)); assertEquals("Couldn't run ML memory update - node is shutting down", exception.get().getMessage()); } - private PersistentTasksCustomMetaData.PersistentTask makeTestTask(String jobId) { - return new PersistentTasksCustomMetaData.PersistentTask<>("job-" + jobId, MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams(jobId), - 0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT); + private PersistentTasksCustomMetaData.PersistentTask makeTestAnomalyDetectorTask(String jobId) { + return new PersistentTasksCustomMetaData.PersistentTask<>(MlTasks.jobTaskId(jobId), MlTasks.JOB_TASK_NAME, + new OpenJobAction.JobParams(jobId), 0, 
PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT); + } + + private + PersistentTasksCustomMetaData.PersistentTask makeTestDataFrameAnalyticsTask(String id) { + return new PersistentTasksCustomMetaData.PersistentTask<>(MlTasks.dataFrameAnalyticsTaskId(id), + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, new StartDataFrameAnalyticsAction.TaskParams(id), 0, + PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT); } } From 139b956fd454301085c2dde6f9598a9b92a732a6 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 1 Apr 2019 17:18:58 +0300 Subject: [PATCH 28/67] [FEATURE][ML] Implement data frame evaluate API (#40328) Adds a new API: ``` POST _ml/data_frame/_evaluate ``` which enables evaluating various ML methods. --- .../xpack/core/XPackClientPlugin.java | 2 + .../ml/action/EvaluateDataFrameAction.java | 207 ++++++++ .../ml/dataframe/evaluation/Evaluation.java | 35 ++ .../evaluation/EvaluationMetricResult.java | 20 + .../evaluation/EvaluationResult.java | 20 + .../MetricListEvaluationResult.java | 58 +++ .../MlEvaluationNamedXContentProvider.java | 73 +++ .../AbstractConfusionMatrixMetric.java | 102 ++++ .../evaluation/softclassification/AucRoc.java | 342 +++++++++++++ .../BinarySoftClassification.java | 214 ++++++++ .../softclassification/ConfusionMatrix.java | 163 ++++++ .../softclassification/Precision.java | 91 ++++ .../evaluation/softclassification/Recall.java | 91 ++++ .../ScoreByThresholdResult.java | 63 +++ .../SoftClassificationMetric.java | 60 +++ .../EvaluateDataFrameActionRequestTests.java | 58 +++ .../softclassification/AucRocTests.java | 127 +++++ .../BinarySoftClassificationTests.java | 79 +++ .../ConfusionMatrixTests.java | 79 +++ .../softclassification/PrecisionTests.java | 93 ++++ .../softclassification/RecallTests.java | 93 ++++ .../ml/qa/ml-with-security/build.gradle | 12 + .../smoketest/MlWithSecurityUserRoleIT.java | 33 +- .../xpack/ml/MachineLearning.java | 22 +- .../TransportEvaluateDataFrameAction.java | 53 ++ 
.../RestEvaluateDataFrameAction.java | 36 ++ .../api/ml.evaluate_data_frame.json | 14 + .../test/ml/evaluate_data_frame.yml | 466 ++++++++++++++++++ 28 files changed, 2697 insertions(+), 9 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java create mode 100644 
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRocTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestEvaluateDataFrameAction.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index d03655842d9db..f20ee9808f1c5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -84,6 +84,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; @@ -312,6 +313,7 @@ public List> getClientActions() { GetDataFrameAnalyticsStatsAction.INSTANCE, DeleteDataFrameAnalyticsAction.INSTANCE, StartDataFrameAnalyticsAction.INSTANCE, + EvaluateDataFrameAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java new file mode 100644 index 0000000000000..0b51d097532d3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +public class EvaluateDataFrameAction extends Action { + + public static final EvaluateDataFrameAction INSTANCE = new EvaluateDataFrameAction(); + public static final String NAME = "cluster:monitor/xpack/ml/data_frame/evaluate"; + + private EvaluateDataFrameAction() { + super(NAME); + } + + @Override + public Response newResponse() { + return new Response(); + } + + public static class Request extends ActionRequest implements ToXContentObject { + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField EVALUATION = new ParseField("evaluation"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + a -> new Request((List) a[0], (Evaluation) a[1])); + + static { + 
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDEX); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION); + } + + private static Evaluation parseEvaluation(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + Evaluation evaluation = parser.namedObject(Evaluation.class, parser.currentName(), null); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return evaluation; + } + + public static Request parseRequest(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private String[] indices; + private Evaluation evaluation; + + private Request(List indices, Evaluation evaluation) { + setIndices(indices); + setEvaluation(evaluation); + } + + public Request() { + } + + public String[] getIndices() { + return indices; + } + + public final void setIndices(List indices) { + ExceptionsHelper.requireNonNull(indices, INDEX); + if (indices.isEmpty()) { + throw ExceptionsHelper.badRequestException("At least one index must be specified"); + } + this.indices = indices.toArray(new String[indices.size()]); + } + + public Evaluation getEvaluation() { + return evaluation; + } + + public final void setEvaluation(Evaluation evaluation) { + this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + indices = in.readStringArray(); + evaluation = in.readNamedWriteable(Evaluation.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); 
+ out.writeStringArray(indices); + out.writeNamedWriteable(evaluation); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.array(INDEX.getPreferredName(), indices); + builder.startObject(EVALUATION.getPreferredName()); + builder.field(evaluation.getName(), evaluation); + builder.endObject(); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(indices), evaluation); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request that = (Request) o; + return Arrays.equals(indices, that.indices) && Objects.equals(evaluation, that.evaluation); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client) { + super(client, INSTANCE, new Request()); + } + } + + public static class Response extends ActionResponse implements ToXContentObject { + + private EvaluationResult result; + + public Response() { + } + + public Response(EvaluationResult result) { + this.result = Objects.requireNonNull(result); + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + this.result = in.readNamedWriteable(EvaluationResult.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeNamedWriteable(result); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(result.getEvaluationName(), result); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(result); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Response 
other = (Response) obj; + return Objects.equals(result, other.result); + } + + @Override + public final String toString() { + return Strings.toString(this); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java new file mode 100644 index 0000000000000..0089d2e04e894 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.search.builder.SearchSourceBuilder; + +/** + * Defines an evaluation + */ +public interface Evaluation extends ToXContentObject, NamedWriteable { + + /** + * Returns the evaluation name + */ + String getName(); + + /** + * Builds the search required to collect data to compute the evaluation result + */ + SearchSourceBuilder buildSearch(); + + /** + * Computes the evaluation result + * @param searchResponse The search response required to compute the result + * @param listener A listener of the result + */ + void evaluate(SearchResponse searchResponse, ActionListener listener); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java new file mode 100644 
index 0000000000000..36b8adf9d4ea3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * The result of an evaluation metric + */ +public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable { + + /** + * Returns the name of the metric + */ + String getName(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java new file mode 100644 index 0000000000000..60b2701ee14a2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * The result of an evaluation + */ +public interface EvaluationResult extends ToXContentObject, NamedWriteable { + + /** + * Returns the name of the evaluation + */ + String getEvaluationName(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java new file mode 100644 index 0000000000000..cd32e8aaa4594 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class MetricListEvaluationResult implements EvaluationResult { + + public static final String NAME = "metric_list_evaluation_result"; + + private final String evaluationName; + private final List metrics; + + public MetricListEvaluationResult(String evaluationName, List metrics) { + this.evaluationName = Objects.requireNonNull(evaluationName); + this.metrics = Objects.requireNonNull(metrics); + } + + public MetricListEvaluationResult(StreamInput in) throws IOException { + this.evaluationName = in.readString(); + this.metrics = in.readNamedWriteableList(EvaluationMetricResult.class); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getEvaluationName() { + return evaluationName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(evaluationName); + out.writeList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + for (EvaluationMetricResult metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java new file mode 100644 index 0000000000000..a4b00840d2b1b --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -0,0 +1,73 @@ +/* + * 
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; + +import java.util.ArrayList; +import java.util.List; + +public class MlEvaluationNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + List namedXContent = new ArrayList<>(); + + // Evaluations + namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, + BinarySoftClassification::fromXContent)); + + // Soft classification metrics + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent)); + 
namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, + ConfusionMatrix::fromXContent)); + + return namedXContent; + } + + public List getNamedWriteables() { + List namedWriteables = new ArrayList<>(); + + // Evaluations + namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), + BinarySoftClassification::new)); + + // Evaluation Results + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationResult.class, MetricListEvaluationResult.NAME, + MetricListEvaluationResult::new)); + + // Evaluation Metrics + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), + AucRoc::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), + Precision::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), + Recall::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix::new)); + + // Evaluation Metrics Results + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), + AucRoc.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, + ScoreByThresholdResult::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix.Result::new)); + + return namedWriteables; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java new file mode 100644 index 0000000000000..facdcceea194f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric { + + public static final ParseField AT = new ParseField("at"); + + protected final double[] thresholds; + + protected AbstractConfusionMatrixMetric(double[] thresholds) { + this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT); + if (thresholds.length == 0) { + throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName() + + "] must have at least one value"); + } + for (double threshold : thresholds) { + if (threshold < 0 || threshold > 1.0) { + throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." 
+ AT.getPreferredName() + + "] values must be in [0.0, 1.0]"); + } + } + } + + protected AbstractConfusionMatrixMetric(StreamInput in) throws IOException { + this.thresholds = in.readDoubleArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDoubleArray(thresholds); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(AT.getPreferredName(), thresholds); + builder.endObject(); + return builder; + } + + @Override + public final List aggs(String actualField, List classInfos) { + List aggs = new ArrayList<>(); + for (double threshold : thresholds) { + aggs.addAll(aggsAt(actualField, classInfos, threshold)); + } + return aggs; + } + + protected abstract List aggsAt(String labelField, List classInfos, double threshold); + + protected enum Condition { + TP, FP, TN, FN; + } + + protected String aggName(ClassInfo classInfo, double threshold, Condition condition) { + return getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name(); + } + + protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + switch (condition) { + case TP: + boolQuery.must(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); + break; + case FP: + boolQuery.mustNot(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); + break; + case TN: + boolQuery.mustNot(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); + break; + case FN: + boolQuery.must(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); + break; + default: + throw new IllegalArgumentException("Unknown enum 
value: " + condition); + } + return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java new file mode 100644 index 0000000000000..9125a50cd85c5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -0,0 +1,342 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.search.aggregations.metrics.Percentiles; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; 
+import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.stream.IntStream; + +/** + * Area under the curve (AUC) of the receiver operating characteristic (ROC). + * The ROC curve is a plot of the TPR (true positive rate) against + * the FPR (false positive rate) over a varying threshold. + * + * This particular implementation is making use of ES aggregations + * to calculate the curve. It then uses the trapezoidal rule to calculate + * the AUC. + * + * In particular, in order to calculate the ROC, we get percentiles of TP + * and FP against the predicted probability. We call those Rate-Threshold + * curves. We then scan ROC points from each Rate-Threshold curve against the + * other using interpolation. This gives us an approximation of the ROC curve + * that has the advantage of being efficient and resilient to some edge cases. + * + * When this is used for multi-class classification, it will calculate the ROC + * curve of each class versus the rest. + */ +public class AucRoc implements SoftClassificationMetric { + + public static final ParseField NAME = new ParseField("auc_roc"); + + public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new AucRoc((Boolean) a[0])); + + static { + PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE); + } + + private static final String PERCENTILES = "percentiles"; + + public static AucRoc fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final boolean includeCurve; + + public AucRoc(Boolean includeCurve) { + this.includeCurve = includeCurve == null ? 
false : includeCurve; + } + + public AucRoc(StreamInput in) throws IOException { + this.includeCurve = in.readBoolean(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(includeCurve); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve); + builder.endObject(); + return builder; + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRoc that = (AucRoc) o; + return Objects.equals(includeCurve, that.includeCurve); + } + + @Override + public int hashCode() { + return Objects.hash(includeCurve); + } + + @Override + public List aggs(String actualField, List classInfos) { + double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); + List aggs = new ArrayList<>(); + for (ClassInfo classInfo : classInfos) { + AggregationBuilder percentilesForClassValueAgg = AggregationBuilders + .filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery()) + .subAggregation( + AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); + AggregationBuilder percentilesForRestAgg = AggregationBuilders + .filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery())) + .subAggregation( + AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); + aggs.add(percentilesForClassValueAgg); + aggs.add(percentilesForRestAgg); + } + return aggs; + } + + private String evaluatedLabelAggName(ClassInfo classInfo) { + return getMetricName() + "_" + classInfo.getName(); + 
} + + private String restLabelsAggName(ClassInfo classInfo) { + return getMetricName() + "_non_" + classInfo.getName(); + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo)); + Filter restAgg = aggs.get(restLabelsAggName(classInfo)); + double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES)); + double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES)); + List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = calculateAucScore(aucRocCurve); + return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); + } + + private static double[] percentilesArray(Percentiles percentiles) { + double[] result = new double[99]; + percentiles.forEach(percentile -> result[((int) percentile.getPercent()) - 1] = percentile.getValue()); + return result; + } + + /** + * Visible for testing + */ + static List buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) { + assert tpPercentiles.length == fpPercentiles.length; + assert tpPercentiles.length == 99; + + List aucRocCurve = new ArrayList<>(); + aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0)); + aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0)); + RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true); + RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false); + aucRocCurve.addAll(tpCurve.scanPoints(fpCurve)); + aucRocCurve.addAll(fpCurve.scanPoints(tpCurve)); + Collections.sort(aucRocCurve); + return aucRocCurve; + } + + /** + * Visible for testing + */ + static double calculateAucScore(List rocCurve) { + // Calculates AUC based on the trapezoid rule + double aucRoc = 0.0; + for (int i = 1; i < rocCurve.size(); i++) { + AucRocPoint left = rocCurve.get(i - 1); + AucRocPoint right = rocCurve.get(i); + aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2; + } + 
return aucRoc; + } + + private static class RateThresholdCurve { + + private final double[] percentiles; + private final boolean isTp; + + private RateThresholdCurve(double[] percentiles, boolean isTp) { + this.percentiles = percentiles; + this.isTp = isTp; + } + + private double getRate(int index) { + return 1 - 0.01 * (index + 1); + } + + private double getThreshold(int index) { + return percentiles[index]; + } + + private double interpolateRate(double threshold) { + int binarySearchResult = Arrays.binarySearch(percentiles, threshold); + if (binarySearchResult >= 0) { + return getRate(binarySearchResult); + } else { + int right = (binarySearchResult * -1) -1; + int left = right - 1; + if (right >= percentiles.length) { + return 0.0; + } else if (left < 0) { + return 1.0; + } else { + double rightRate = getRate(right); + double leftRate = getRate(left); + return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate); + } + } + } + + private List scanPoints(RateThresholdCurve againstCurve) { + List points = new ArrayList<>(); + for (int index = 0; index < percentiles.length; index++) { + double rate = getRate(index); + double scannedThreshold = getThreshold(index); + double againstRate = againstCurve.interpolateRate(scannedThreshold); + AucRocPoint point; + if (isTp) { + point = new AucRocPoint(rate, againstRate, scannedThreshold); + } else { + point = new AucRocPoint(againstRate, rate, scannedThreshold); + } + points.add(point); + } + return points; + } + } + + public static final class AucRocPoint implements Comparable, ToXContentObject, Writeable { + double tpr; + double fpr; + double threshold; + + private AucRocPoint(double tpr, double fpr, double threshold) { + this.tpr = tpr; + this.fpr = fpr; + this.threshold = threshold; + } + + private AucRocPoint(StreamInput in) throws IOException { + this.tpr = in.readDouble(); + this.fpr = in.readDouble(); + this.threshold = in.readDouble(); + } + + @Override + public int 
compareTo(AucRocPoint o) { + return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed() + .thenComparing(p -> p.fpr) + .thenComparing(p -> p.tpr) + .compare(this, o); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(tpr); + out.writeDouble(fpr); + out.writeDouble(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("tpr", tpr); + builder.field("fpr", fpr); + builder.field("threshold", threshold); + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + private static double interpolate(double x, double x1, double y1, double x2, double y2) { + return y1 + (x - x1) * (y2 - y1) / (x2 - x1); + } + + public static class Result implements EvaluationMetricResult { + + private final double score; + private final List curve; + + public Result(double score, List curve) { + this.score = score; + this.curve = Objects.requireNonNull(curve); + } + + public Result(StreamInput in) throws IOException { + this.score = in.readDouble(); + this.curve = in.readList(AucRocPoint::new); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(score); + out.writeList(curve); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("score", score); + if (curve.isEmpty() == false) { + builder.field("curve", curve); + } + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java new file mode 100644 index 0000000000000..27a922f086c53 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -0,0 +1,214 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MetricListEvaluationResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import 
java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of binary soft classification methods, e.g. outlier detection. + * This is useful to evaluate problems where a model outputs a probability of whether + * a data frame row belongs to one of two groups. + */ +public class BinarySoftClassification implements Evaluation { + + public static final ParseField NAME = new ParseField("binary_soft_classification"); + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(SoftClassificationMetric.class, n, null), METRICS); + } + + public static BinarySoftClassification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field where the actual class is marked up. + * The value of this field is assumed to either be 1 or 0, or true or false. + */ + private final String actualField; + + /** + * The field of the predicted probability in [0.0, 1.0]. 
+ */ + private final String predictedProbabilityField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public BinarySoftClassification(String actualField, String predictedProbabilityField, + @Nullable List metrics) { + this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); + this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD); + this.metrics = initMetrics(metrics); + } + + private static List initMetrics(@Nullable List parsedMetrics) { + List metrics = parsedMetrics == null ? defaultMetrics() : parsedMetrics; + if (metrics.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); + } + Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getMetricName)); + return metrics; + } + + private static List defaultMetrics() { + List defaultMetrics = new ArrayList<>(4); + defaultMetrics.add(new AucRoc(false)); + defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75))); + defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75))); + defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75))); + return defaultMetrics; + } + + public BinarySoftClassification(StreamInput in) throws IOException { + this.actualField = in.readString(); + this.predictedProbabilityField = in.readString(); + this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualField); + out.writeString(predictedProbabilityField); + out.writeNamedWriteableList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + 
builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); + + builder.startObject(METRICS.getPreferredName()); + for (SoftClassificationMetric metric : metrics) { + builder.field(metric.getMetricName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BinarySoftClassification that = (BinarySoftClassification) o; + return Objects.equals(actualField, that.actualField) + && Objects.equals(predictedProbabilityField, that.predictedProbabilityField) + && Objects.equals(metrics, that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedProbabilityField, metrics); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public SearchSourceBuilder buildSearch() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.size(0); + searchSourceBuilder.query(buildQuery()); + for (SoftClassificationMetric metric : metrics) { + List aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo())); + aggs.forEach(searchSourceBuilder::aggregation); + } + return searchSourceBuilder; + } + + private QueryBuilder buildQuery() { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + boolQuery.filter(QueryBuilders.existsQuery(actualField)); + boolQuery.filter(QueryBuilders.existsQuery(predictedProbabilityField)); + return boolQuery; + } + + @Override + public void evaluate(SearchResponse searchResponse, ActionListener listener) { + if (searchResponse.getHits().getTotalHits().value == 0) { + listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, + predictedProbabilityField)); + return; + } + + List results = new ArrayList<>(); + Aggregations aggs = 
searchResponse.getAggregations(); + BinaryClassInfo binaryClassInfo = new BinaryClassInfo(); + for (SoftClassificationMetric metric : metrics) { + results.add(metric.evaluate(binaryClassInfo, aggs)); + } + listener.onResponse(new MetricListEvaluationResult(NAME.getPreferredName(), results)); + } + + private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo { + + private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": 1 OR true"); + + @Override + public String getName() { + return String.valueOf(true); + } + + @Override + public QueryBuilder matchingQuery() { + return matchingQuery; + } + + @Override + public String getProbabilityField() { + return predictedProbabilityField; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java new file mode 100644 index 0000000000000..54f245962d515 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java @@ -0,0 +1,163 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class ConfusionMatrix extends AbstractConfusionMatrixMetric { + + public static final ParseField NAME = new ParseField("confusion_matrix"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new ConfusionMatrix((List) a[0])); + + static { + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT); + } + + public static ConfusionMatrix fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public ConfusionMatrix(List at) { + super(at.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public ConfusionMatrix(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfusionMatrix that = (ConfusionMatrix) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return 
Arrays.hashCode(thresholds); + } + + @Override + protected List aggsAt(String labelField, List classInfos, double threshold) { + List aggs = new ArrayList<>(); + for (ClassInfo classInfo : classInfos) { + aggs.add(buildAgg(classInfo, threshold, Condition.TP)); + aggs.add(buildAgg(classInfo, threshold, Condition.FP)); + aggs.add(buildAgg(classInfo, threshold, Condition.TN)); + aggs.add(buildAgg(classInfo, threshold, Condition.FN)); + } + return aggs; + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + long[] tp = new long[thresholds.length]; + long[] fp = new long[thresholds.length]; + long[] tn = new long[thresholds.length]; + long[] fn = new long[thresholds.length]; + for (int i = 0; i < thresholds.length; i++) { + Filter tpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TP)); + Filter fpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FP)); + Filter tnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TN)); + Filter fnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FN)); + tp[i] = tpAgg.getDocCount(); + fp[i] = fpAgg.getDocCount(); + tn[i] = tnAgg.getDocCount(); + fn[i] = fnAgg.getDocCount(); + } + return new Result(thresholds, tp, fp, tn, fn); + } + + public static class Result implements EvaluationMetricResult { + + private final double[] thresholds; + private final long[] tp; + private final long[] fp; + private final long[] tn; + private final long[] fn; + + public Result(double[] thresholds, long[] tp, long[] fp, long[] tn, long[] fn) { + assert thresholds.length == tp.length; + assert thresholds.length == fp.length; + assert thresholds.length == tn.length; + assert thresholds.length == fn.length; + this.thresholds = thresholds; + this.tp = tp; + this.fp = fp; + this.tn = tn; + this.fn = fn; + } + + public Result(StreamInput in) throws IOException { + this.thresholds = in.readDoubleArray(); + this.tp = in.readLongArray(); + this.fp = in.readLongArray(); + this.tn = 
in.readLongArray(); + this.fn = in.readLongArray(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDoubleArray(thresholds); + out.writeLongArray(tp); + out.writeLongArray(fp); + out.writeLongArray(tn); + out.writeLongArray(fn); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + for (int i = 0; i < thresholds.length; i++) { + builder.startObject(String.valueOf(thresholds[i])); + builder.field("tp", tp[i]); + builder.field("fp", fp[i]); + builder.field("tn", tn[i]); + builder.field("fn", fn[i]); + builder.endObject(); + } + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java new file mode 100644 index 0000000000000..d38a52bb203e8 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class Precision extends AbstractConfusionMatrixMetric { + + public static final ParseField NAME = new ParseField("precision"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new Precision((List) a[0])); + + static { + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT); + } + + public static Precision fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public Precision(List at) { + super(at.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public Precision(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Precision that = (Precision) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + @Override + protected List aggsAt(String labelField, List classInfos, double threshold) { + List aggs = new ArrayList<>(); + for 
(ClassInfo classInfo : classInfos) { + aggs.add(buildAgg(classInfo, threshold, Condition.TP)); + aggs.add(buildAgg(classInfo, threshold, Condition.FP)); + } + return aggs; + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + double[] precisions = new double[thresholds.length]; + for (int i = 0; i < precisions.length; i++) { + double threshold = thresholds[i]; + Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP)); + Filter fpAgg = aggs.get(aggName(classInfo, threshold, Condition.FP)); + long tp = tpAgg.getDocCount(); + long fp = fpAgg.getDocCount(); + precisions[i] = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp); + } + return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, precisions); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java new file mode 100644 index 0000000000000..5c4ab57241d95 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class Recall extends AbstractConfusionMatrixMetric { + + public static final ParseField NAME = new ParseField("recall"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new Recall((List) a[0])); + + static { + PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT); + } + + public static Recall fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public Recall(List at) { + super(at.stream().mapToDouble(Double::doubleValue).toArray()); + } + + public Recall(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Recall that = (Recall) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + @Override + protected List aggsAt(String actualField, List classInfos, double threshold) { + List aggs = new ArrayList<>(); + for (ClassInfo classInfo: 
classInfos) { + aggs.add(buildAgg(classInfo, threshold, Condition.TP)); + aggs.add(buildAgg(classInfo, threshold, Condition.FN)); + } + return aggs; + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + double[] recalls = new double[thresholds.length]; + for (int i = 0; i < recalls.length; i++) { + double threshold = thresholds[i]; + Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP)); + Filter fnAgg =aggs.get(aggName(classInfo, threshold, Condition.FN)); + long tp = tpAgg.getDocCount(); + long fn = fnAgg.getDocCount(); + recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn); + } + return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, recalls); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java new file mode 100644 index 0000000000000..bd6b6e7db25a1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.Objects; + +public class ScoreByThresholdResult implements EvaluationMetricResult { + + public static final String NAME = "score_by_threshold_result"; + + private final String name; + private final double[] thresholds; + private final double[] scores; + + public ScoreByThresholdResult(String name, double[] thresholds, double[] scores) { + assert thresholds.length == scores.length; + this.name = Objects.requireNonNull(name); + this.thresholds = thresholds; + this.scores = scores; + } + + public ScoreByThresholdResult(StreamInput in) throws IOException { + this.name = in.readString(); + this.thresholds = in.readDoubleArray(); + this.scores = in.readDoubleArray(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return name; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + out.writeDoubleArray(thresholds); + out.writeDoubleArray(scores); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + for (int i = 0; i < thresholds.length; i++) { + builder.field(String.valueOf(thresholds[i]), scores[i]); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java new file mode 100644 index 
0000000000000..dfb256e9b52f2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java @@ -0,0 +1,60 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.util.List; + +public interface SoftClassificationMetric extends ToXContentObject, NamedWriteable { + + /** + * The information of a specific class + */ + interface ClassInfo { + + /** + * Returns the class name + */ + String getName(); + + /** + * Returns a query that matches documents of the class + */ + QueryBuilder matchingQuery(); + + /** + * Returns the field that has the probability to be of the class + */ + String getProbabilityField(); + } + + /** + * Returns the name of the metric (which may differ to the writeable name) + */ + String getMetricName(); + + /** + * Builds the aggregation that collect required data to compute the metric + * @param actualField the field that stores the actual class + * @param classInfos the information of each class to compute the metric for + * @return the aggregations required to compute the metric + */ + List aggs(String actualField, List classInfos); + + /** + * Calculates the metric result for a given class + * @param classInfo the class to calculate the metric for + * @param aggs the aggregations + * @return 
the metric result + */ + EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs); +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java new file mode 100644 index 0000000000000..e899b7e6642da --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractStreamableXContentTestCase; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Request; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests; + +import java.util.ArrayList; +import java.util.List; + +public class EvaluateDataFrameActionRequestTests extends AbstractStreamableXContentTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected Request createTestInstance() { + Request request = new Request(); + int 
indicesCount = randomIntBetween(1, 5); + List indices = new ArrayList<>(indicesCount); + for (int i = 0; i < indicesCount; i++) { + indices.add(randomAlphaOfLength(10)); + } + request.setIndices(indices); + request.setEvaluation(BinarySoftClassificationTests.createRandom()); + return request; + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected Request createBlankInstance() { + return new Request(); + } + + @Override + protected Request doParseInstance(XContentParser parser) { + return Request.parseRequest(parser); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRocTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRocTests.java new file mode 100644 index 0000000000000..6f8ca9339715d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRocTests.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; + +public class AucRocTests extends AbstractSerializingTestCase { + + @Override + protected AucRoc doParseInstance(XContentParser parser) throws IOException { + return AucRoc.PARSER.apply(parser, null); + } + + @Override + protected AucRoc createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return AucRoc::new; + } + + public static AucRoc createRandom() { + return new AucRoc(randomBoolean() ? randomBoolean() : null); + } + + public void testCalculateAucScore_GivenZeroPercentiles() { + double[] tpPercentiles = zeroPercentiles(); + double[] fpPercentiles = zeroPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + assertThat(aucRocScore, closeTo(0.5, 0.01)); + } + + public void testCalculateAucScore_GivenRandomTpPercentilesAndZeroFpPercentiles() { + double[] tpPercentiles = randomPercentiles(); + double[] fpPercentiles = zeroPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + assertThat(aucRocScore, closeTo(1.0, 0.1)); + } + + public void testCalculateAucScore_GivenZeroTpPercentilesAndRandomFpPercentiles() { + double[] tpPercentiles = zeroPercentiles(); + double[] fpPercentiles = randomPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); 
+ + assertThat(aucRocScore, closeTo(0.0, 0.1)); + } + + public void testCalculateAucScore_GivenRandomPercentiles() { + for (int i = 0; i < 20; i++) { + double[] tpPercentiles = randomPercentiles(); + double[] fpPercentiles = randomPercentiles(); + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); + double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); + + assertThat(aucRocScore, greaterThanOrEqualTo(0.0)); + assertThat(aucRocScore, lessThanOrEqualTo(1.0)); + assertThat(inverseAucRocScore, greaterThanOrEqualTo(0.0)); + assertThat(inverseAucRocScore, lessThanOrEqualTo(1.0)); + assertThat(aucRocScore + inverseAucRocScore, closeTo(1.0, 0.05)); + } + } + + public void testCalculateAucScore_GivenPrecalculated() { + double[] tpPercentiles = new double[99]; + double[] fpPercentiles = new double[99]; + + double[] tpSimplified = new double[] { 0.3, 0.6, 0.5 , 0.8 }; + double[] fpSimplified = new double[] { 0.1, 0.3, 0.5 , 0.5 }; + + for (int i = 0; i < tpPercentiles.length; i++) { + int simplifiedIndex = i / 25; + tpPercentiles[i] = tpSimplified[simplifiedIndex]; + fpPercentiles[i] = fpSimplified[simplifiedIndex]; + } + + List curve = AucRoc.buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = AucRoc.calculateAucScore(curve); + + List inverseCurve = AucRoc.buildAucRocCurve(fpPercentiles, tpPercentiles); + double inverseAucRocScore = AucRoc.calculateAucScore(inverseCurve); + + assertThat(aucRocScore, closeTo(0.8, 0.05)); + assertThat(inverseAucRocScore, closeTo(0.2, 0.05)); + } + + public static double[] zeroPercentiles() { + double[] percentiles = new double[99]; + Arrays.fill(percentiles, 0.0); + return percentiles; + } + + public static double[] randomPercentiles() { + double[] percentiles = new double[99]; + for (int i = 0; i < percentiles.length; i++) { + percentiles[i] = randomDouble(); 
+ } + Arrays.sort(percentiles); + return percentiles; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java new file mode 100644 index 0000000000000..4f17df3536731 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassificationTests.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; + +public class BinarySoftClassificationTests extends AbstractSerializingTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + public static 
BinarySoftClassification createRandom() { + List metrics = new ArrayList<>(); + if (randomBoolean()) { + metrics.add(AucRocTests.createRandom()); + } + if (randomBoolean()) { + metrics.add(PrecisionTests.createRandom()); + } + if (randomBoolean()) { + metrics.add(RecallTests.createRandom()); + } + if (randomBoolean()) { + metrics.add(ConfusionMatrixTests.createRandom()); + } + if (metrics.isEmpty()) { + // not a good day to play in the lottery; let's add them all + metrics.add(AucRocTests.createRandom()); + metrics.add(PrecisionTests.createRandom()); + metrics.add(RecallTests.createRandom()); + metrics.add(ConfusionMatrixTests.createRandom()); + } + return new BinarySoftClassification(randomAlphaOfLength(10), randomAlphaOfLength(10), metrics); + } + + @Override + protected BinarySoftClassification doParseInstance(XContentParser parser) throws IOException { + return BinarySoftClassification.fromXContent(parser); + } + + @Override + protected BinarySoftClassification createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return BinarySoftClassification::new; + } + + public void testConstructor_GivenEmptyMetrics() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new BinarySoftClassification("foo", "bar", Collections.emptyList())); + assertThat(e.getMessage(), equalTo("[binary_soft_classification] must have one or more metrics")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java new file mode 100644 index 0000000000000..41f78051af420 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrixTests.java @@ -0,0 +1,79 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ConfusionMatrixTests extends AbstractSerializingTestCase { + + @Override + protected ConfusionMatrix doParseInstance(XContentParser parser) throws IOException { + return ConfusionMatrix.fromXContent(parser); + } + + @Override + protected ConfusionMatrix createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ConfusionMatrix::new; + } + + public static ConfusionMatrix createRandom() { + int thresholdsSize = randomIntBetween(1, 3); + List thresholds = new ArrayList<>(thresholdsSize); + for (int i = 0; i < thresholdsSize; i++) { + thresholds.add(randomDouble()); + } + return new ConfusionMatrix(thresholds); + } + + public void testEvaluate() { + SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); + when(classInfo.getName()).thenReturn("foo"); + + Aggregations aggs = new Aggregations(Arrays.asList( + createFilterAgg("confusion_matrix_foo_at_0.25_TP", 1L), + 
createFilterAgg("confusion_matrix_foo_at_0.25_FP", 2L), + createFilterAgg("confusion_matrix_foo_at_0.25_TN", 3L), + createFilterAgg("confusion_matrix_foo_at_0.25_FN", 4L), + createFilterAgg("confusion_matrix_foo_at_0.5_TP", 5L), + createFilterAgg("confusion_matrix_foo_at_0.5_FP", 6L), + createFilterAgg("confusion_matrix_foo_at_0.5_TN", 7L), + createFilterAgg("confusion_matrix_foo_at_0.5_FN", 8L) + )); + + ConfusionMatrix confusionMatrix = new ConfusionMatrix(Arrays.asList(0.25, 0.5)); + EvaluationMetricResult result = confusionMatrix.evaluate(classInfo, aggs); + + String expected = "{\"0.25\":{\"tp\":1,\"fp\":2,\"tn\":3,\"fn\":4},\"0.5\":{\"tp\":5,\"fp\":6,\"tn\":7,\"fn\":8}}"; + assertThat(Strings.toString(result), equalTo(expected)); + } + + private static Filter createFilterAgg(String name, long docCount) { + Filter agg = mock(Filter.class); + when(agg.getName()).thenReturn(name); + when(agg.getDocCount()).thenReturn(docCount); + return agg; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java new file mode 100644 index 0000000000000..c12156c39373e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/PrecisionTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class PrecisionTests extends AbstractSerializingTestCase { + + @Override + protected Precision doParseInstance(XContentParser parser) throws IOException { + return Precision.fromXContent(parser); + } + + @Override + protected Precision createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Precision::new; + } + + public static Precision createRandom() { + int thresholdsSize = randomIntBetween(1, 3); + List thresholds = new ArrayList<>(thresholdsSize); + for (int i = 0; i < thresholdsSize; i++) { + thresholds.add(randomDouble()); + } + return new Precision(thresholds); + } + + public void testEvaluate() { + SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); + when(classInfo.getName()).thenReturn("foo"); + + Aggregations aggs = new Aggregations(Arrays.asList( + createFilterAgg("precision_foo_at_0.25_TP", 1L), + createFilterAgg("precision_foo_at_0.25_FP", 4L), + createFilterAgg("precision_foo_at_0.5_TP", 3L), + createFilterAgg("precision_foo_at_0.5_FP", 1L), + createFilterAgg("precision_foo_at_0.75_TP", 5L), + createFilterAgg("precision_foo_at_0.75_FP", 0L) + )); + + Precision precision = new 
Precision(Arrays.asList(0.25, 0.5, 0.75)); + EvaluationMetricResult result = precision.evaluate(classInfo, aggs); + + String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}"; + assertThat(Strings.toString(result), equalTo(expected)); + } + + public void testEvaluate_GivenZeroTpAndFp() { + SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); + when(classInfo.getName()).thenReturn("foo"); + + Aggregations aggs = new Aggregations(Arrays.asList( + createFilterAgg("precision_foo_at_1.0_TP", 0L), + createFilterAgg("precision_foo_at_1.0_FP", 0L) + )); + + Precision precision = new Precision(Arrays.asList(1.0)); + EvaluationMetricResult result = precision.evaluate(classInfo, aggs); + + String expected = "{\"1.0\":0.0}"; + assertThat(Strings.toString(result), equalTo(expected)); + } + + private static Filter createFilterAgg(String name, long docCount) { + Filter agg = mock(Filter.class); + when(agg.getName()).thenReturn(name); + when(agg.getDocCount()).thenReturn(docCount); + return agg; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java new file mode 100644 index 0000000000000..fc85b44f151d4 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/RecallTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class RecallTests extends AbstractSerializingTestCase { + + @Override + protected Recall doParseInstance(XContentParser parser) throws IOException { + return Recall.fromXContent(parser); + } + + @Override + protected Recall createTestInstance() { + return createRandom(); + } + + @Override + protected Writeable.Reader instanceReader() { + return Recall::new; + } + + public static Recall createRandom() { + int thresholdsSize = randomIntBetween(1, 3); + List thresholds = new ArrayList<>(thresholdsSize); + for (int i = 0; i < thresholdsSize; i++) { + thresholds.add(randomDouble()); + } + return new Recall(thresholds); + } + + public void testEvaluate() { + SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); + when(classInfo.getName()).thenReturn("foo"); + + Aggregations aggs = new Aggregations(Arrays.asList( + createFilterAgg("recall_foo_at_0.25_TP", 1L), + createFilterAgg("recall_foo_at_0.25_FN", 4L), + createFilterAgg("recall_foo_at_0.5_TP", 3L), + createFilterAgg("recall_foo_at_0.5_FN", 1L), + createFilterAgg("recall_foo_at_0.75_TP", 5L), + createFilterAgg("recall_foo_at_0.75_FN", 0L) + )); + + Recall recall = new Recall(Arrays.asList(0.25, 0.5, 0.75)); + 
EvaluationMetricResult result = recall.evaluate(classInfo, aggs); + + String expected = "{\"0.25\":0.2,\"0.5\":0.75,\"0.75\":1.0}"; + assertThat(Strings.toString(result), equalTo(expected)); + } + + public void testEvaluate_GivenZeroTpAndFn() { + SoftClassificationMetric.ClassInfo classInfo = mock(SoftClassificationMetric.ClassInfo.class); + when(classInfo.getName()).thenReturn("foo"); + + Aggregations aggs = new Aggregations(Arrays.asList( + createFilterAgg("recall_foo_at_1.0_TP", 0L), + createFilterAgg("recall_foo_at_1.0_FN", 0L) + )); + + Recall recall = new Recall(Arrays.asList(1.0)); + EvaluationMetricResult result = recall.evaluate(classInfo, aggs); + + String expected = "{\"1.0\":0.0}"; + assertThat(Strings.toString(result), equalTo(expected)); + } + + private static Filter createFilterAgg(String name, long docCount) { + Filter agg = mock(Filter.class); + when(agg.getName()).thenReturn(name); + when(agg.getDocCount()).thenReturn(docCount); + return agg; + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle index f033fd4ed3610..b2dc5bc54f76e 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -48,6 +48,18 @@ integTestRunner { 'ml/data_frame_analytics_crud/Test get given missing analytics', 'ml/data_frame_analytics_crud/Test delete given missing config', 'ml/data_frame_analytics_crud/Test max model memory limit', + 'ml/evaluate_data_frame/Test given missing index', + 'ml/evaluate_data_frame/Test given index does not exist', + 'ml/evaluate_data_frame/Test given missing evaluation', + 'ml/evaluate_data_frame/Test binary_soft_classification given evaluation with emtpy metrics', + 'ml/evaluate_data_frame/Test binary_soft_classification given missing actual_field', + 'ml/evaluate_data_frame/Test binary_soft_classification given missing predicted_probability_field', + 'ml/evaluate_data_frame/Test binary_soft_classification given 
precision with threshold less than zero', + 'ml/evaluate_data_frame/Test binary_soft_classification given recall with threshold less than zero', + 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with threshold less than zero', + 'ml/evaluate_data_frame/Test binary_soft_classification given precision with empty thresholds', + 'ml/evaluate_data_frame/Test binary_soft_classification given recall with empty thresholds', + 'ml/evaluate_data_frame/Test binary_soft_classification given confusion_matrix with empty thresholds', 'ml/delete_job_force/Test cannot force delete a non-existent job', 'ml/delete_model_snapshot/Test delete snapshot missing snapshotId', 'ml/delete_model_snapshot/Test delete snapshot missing job_id', diff --git a/x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityUserRoleIT.java b/x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityUserRoleIT.java index 67b72a648db60..803281257d25a 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityUserRoleIT.java +++ b/x-pack/plugin/ml/qa/ml-with-security/src/test/java/org/elasticsearch/smoketest/MlWithSecurityUserRoleIT.java @@ -6,18 +6,30 @@ package org.elasticsearch.smoketest; import com.carrotsearch.randomizedtesting.annotations.Name; - import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.section.DoSection; import org.elasticsearch.test.rest.yaml.section.ExecutableSection; import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.either; public class MlWithSecurityUserRoleIT extends MlWithSecurityIT { + /** + * These are actions that require the monitor role and/or access to the relevant source index. + * ml_user should have both of these in the tests. 
+ */ + private static final List ALLOWED_ACTION_PATTERNS = Arrays.asList( + Pattern.compile("ml\\.get_.*"), + Pattern.compile("ml\\.find_file_structure"), + Pattern.compile("ml\\.evaluate_data_frame") + ); + private final ClientYamlTestCandidate testCandidate; public MlWithSecurityUserRoleIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { @@ -30,14 +42,12 @@ public void test() throws IOException { try { super.test(); - // We should have got here if and only if the only ML endpoints in the test were GETs - // or the find_file_structure API, which is also available to the machine_learning_user - // role + // We should have got here if and only if the only ML endpoints in the test were in the allowed list for (ExecutableSection section : testCandidate.getTestSection().getExecutableSections()) { if (section instanceof DoSection) { - if (((DoSection) section).getApiCallSection().getApi().startsWith("ml.") && - ((DoSection) section).getApiCallSection().getApi().startsWith("ml.get_") == false && - ((DoSection) section).getApiCallSection().getApi().equals("ml.find_file_structure") == false) { + String apiName = ((DoSection) section).getApiCallSection().getApi(); + + if (((DoSection) section).getApiCallSection().getApi().startsWith("ml.") && isAllowed(apiName) == false) { fail("should have failed because of missing role"); } } @@ -50,6 +60,15 @@ public void test() throws IOException { } } + private static boolean isAllowed(String apiName) { + for (Pattern pattern : ALLOWED_ACTION_PATTERNS) { + if (pattern.matcher(apiName).find()) { + return true; + } + } + return false; + } + @Override protected String[] getCredentials() { return new String[]{"ml_user", "x-pack-test-password"}; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index e5880ceb408aa..13ca9cfaa95f6 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -69,6 +69,7 @@ import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; @@ -114,6 +115,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction; import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; @@ -129,6 +131,7 @@ import org.elasticsearch.xpack.ml.action.TransportDeleteForecastAction; import org.elasticsearch.xpack.ml.action.TransportDeleteJobAction; import org.elasticsearch.xpack.ml.action.TransportDeleteModelSnapshotAction; +import org.elasticsearch.xpack.ml.action.TransportEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.action.TransportFinalizeJobExecutionAction; import org.elasticsearch.xpack.ml.action.TransportFindFileStructureAction; import org.elasticsearch.xpack.ml.action.TransportFlushJobAction; @@ -225,6 +228,7 @@ import org.elasticsearch.xpack.ml.rest.datafeeds.RestStopDatafeedAction; import org.elasticsearch.xpack.ml.rest.datafeeds.RestUpdateDatafeedAction; import 
org.elasticsearch.xpack.ml.rest.dataframe.RestDeleteDataFrameAnalyticsAction; +import org.elasticsearch.xpack.ml.rest.dataframe.RestEvaluateDataFrameAction; import org.elasticsearch.xpack.ml.rest.dataframe.RestGetDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.rest.dataframe.RestGetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.ml.rest.dataframe.RestPutDataFrameAnalyticsAction; @@ -612,7 +616,8 @@ public List getRestHandlers(Settings settings, RestController restC new RestGetDataFrameAnalyticsStatsAction(settings, restController), new RestPutDataFrameAnalyticsAction(settings, restController), new RestDeleteDataFrameAnalyticsAction(settings, restController), - new RestStartDataFrameAnalyticsAction(settings, restController) + new RestStartDataFrameAnalyticsAction(settings, restController), + new RestEvaluateDataFrameAction(settings, restController) ); } @@ -676,7 +681,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(GetDataFrameAnalyticsStatsAction.INSTANCE, TransportGetDataFrameAnalyticsStatsAction.class), new ActionHandler<>(PutDataFrameAnalyticsAction.INSTANCE, TransportPutDataFrameAnalyticsAction.class), new ActionHandler<>(DeleteDataFrameAnalyticsAction.INSTANCE, TransportDeleteDataFrameAnalyticsAction.class), - new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class) + new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class), + new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class) ); } @@ -845,4 +851,16 @@ static long machineMemoryFromStats(OsStats stats) { } return mem; } + + @Override + public List getNamedWriteables() { + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + return namedWriteables; + } + + @Override + public List getNamedXContent() { + return new 
MlEvaluationNamedXContentProvider().getNamedXContentParsers(); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java new file mode 100644 index 0000000000000..62aa2efbfd22e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEvaluateDataFrameAction.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationResult; + +public class TransportEvaluateDataFrameAction extends HandledTransportAction { + + private final ThreadPool threadPool; + private final Client client; + + @Inject + public TransportEvaluateDataFrameAction(TransportService transportService, ActionFilters actionFilters, ThreadPool threadPool, + Client client) { + super(EvaluateDataFrameAction.NAME, transportService, actionFilters, EvaluateDataFrameAction.Request::new); + this.threadPool = threadPool; + this.client = client; + } + + 
@Override + protected void doExecute(Task task, EvaluateDataFrameAction.Request request, + ActionListener listener) { + Evaluation evaluation = request.getEvaluation(); + SearchRequest searchRequest = new SearchRequest(request.getIndices()); + searchRequest.source(evaluation.buildSearch()); + + ActionListener resultListener = ActionListener.wrap( + result -> listener.onResponse(new EvaluateDataFrameAction.Response(result)), + listener::onFailure + ); + + client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + searchResponse -> threadPool.generic().execute(() -> evaluation.evaluate(searchResponse, resultListener)), + listener::onFailure + )); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestEvaluateDataFrameAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestEvaluateDataFrameAction.java new file mode 100644 index 0000000000000..3b514e1283859 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestEvaluateDataFrameAction.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.dataframe; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +public class RestEvaluateDataFrameAction extends BaseRestHandler { + + public RestEvaluateDataFrameAction(Settings settings, RestController controller) { + super(settings); + controller.registerHandler(RestRequest.Method.POST, MachineLearning.BASE_PATH + "data_frame/_evaluate", this); + } + + @Override + public String getName() { + return "ml_evaluate_data_frame_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + EvaluateDataFrameAction.Request request = EvaluateDataFrameAction.Request.parseRequest(restRequest.contentOrSourceParamParser()); + return channel -> client.execute(EvaluateDataFrameAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json new file mode 100644 index 0000000000000..1a4859c796095 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json @@ -0,0 +1,14 @@ +{ + "ml.evaluate_data_frame": { + "methods": [ "POST" ], + "url": { + "path": "/_ml/data_frame/_evaluate", + "paths": [ "/_ml/data_frame/_evaluate" ], + "parts": {} + }, + "body": { + "description" : "The evaluation definition", + "required" : true + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml 
b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml new file mode 100644 index 0000000000000..3bbb59c205fb9 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -0,0 +1,466 @@ +setup: + + - do: + index: + index: utopia + body: > + { + "is_outlier": false, + "is_outlier_int": 0, + "outlier_score": 0.0 + } + + - do: + index: + index: utopia + body: > + { + "is_outlier": false, + "is_outlier_int": 0, + "outlier_score": 0.2 + } + + - do: + index: + index: utopia + body: > + { + "is_outlier": false, + "is_outlier_int": 0, + "outlier_score": 0.3 + } + + - do: + index: + index: utopia + body: > + { + "is_outlier": true, + "is_outlier_int": 1, + "outlier_score": 0.3 + } + + - do: + index: + index: utopia + body: > + { + "is_outlier": true, + "is_outlier_int": 1, + "outlier_score": 0.4 + } + + - do: + index: + index: utopia + body: > + { + "is_outlier": true, + "is_outlier_int": 1, + "outlier_score": 0.5 + } + + - do: + index: + index: utopia + body: > + { + "is_outlier": true, + "is_outlier_int": 1, + "outlier_score": 0.9 + } + + - do: + index: + index: utopia + body: > + { + "is_outlier": true, + "is_outlier_int": 1, + "outlier_score": 0.95 + } + + # This document misses the required fields and should be ignored + - do: + index: + index: utopia + body: > + { + "foo": 0.24 + } + + - do: + indices.refresh: {} + +--- +"Test binary_soft_classifition auc_roc": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "auc_roc": {} + } + } + } + } + - match: { binary_soft_classification.auc_roc.score: 0.9899 } + - is_false: binary_soft_classification.auc_roc.curve + +--- +"Test binary_soft_classifition auc_roc given actual_field is int": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + 
"binary_soft_classification": { + "actual_field": "is_outlier_int", + "predicted_probability_field": "outlier_score", + "metrics": { + "auc_roc": {} + } + } + } + } + - match: { binary_soft_classification.auc_roc.score: 0.9899 } + - is_false: binary_soft_classification.auc_roc.curve + +--- +"Test binary_soft_classifition auc_roc include curve": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "auc_roc": { "include_curve": true } + } + } + } + } + - match: { binary_soft_classification.auc_roc.score: 0.9899 } + - is_true: binary_soft_classification.auc_roc.curve + +--- +"Test binary_soft_classifition precision": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "precision": { "at": [0, 0.5] } + } + } + } + } + - match: + binary_soft_classification: + precision: + 0.0: 0.625 + 0.5: 1.0 + +--- +"Test binary_soft_classifition recall": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "recall": { "at": [0, 0.4, 0.5] } + } + } + } + } + - match: + binary_soft_classification: + recall: + 0.0: 1.0 + 0.4: 0.8 + 0.5: 0.6 + +--- +"Test binary_soft_classifition confusion_matrix": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "confusion_matrix": { "at": [0, 0.3, 0.5] } + } + } + } + } + - match: + binary_soft_classification: + confusion_matrix: + 0.0: + tp: 5 + fp: 3 + tn: 0 + fn: 0 + 0.3: + tp: 5 + fp: 1 + 
tn: 2 + fn: 0 + 0.5: + tp: 3 + fp: 0 + tn: 3 + fn: 2 + +--- +"Test binary_soft_classifition default metrics": + - do: + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score" + } + } + } + - is_true: binary_soft_classification.auc_roc.score + - is_true: binary_soft_classification.precision.0\.25 + - is_true: binary_soft_classification.precision.0\.5 + - is_true: binary_soft_classification.precision.0\.75 + - is_true: binary_soft_classification.recall.0\.25 + - is_true: binary_soft_classification.recall.0\.5 + - is_true: binary_soft_classification.recall.0\.75 + - is_true: binary_soft_classification.confusion_matrix.0\.25 + - is_true: binary_soft_classification.confusion_matrix.0\.5 + - is_true: binary_soft_classification.confusion_matrix.0\.75 + +--- +"Test given missing index": + - do: + catch: /Required \[index\]/ + ml.evaluate_data_frame: + body: > + { + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score" + } + } + } + +--- +"Test given index does not exist": + - do: + catch: missing + ml.evaluate_data_frame: + body: > + { + "index": "missing_index", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score" + } + } + } + +--- +"Test given missing evaluation": + - do: + catch: /Required \[evaluation\]/ + ml.evaluate_data_frame: + body: > + { + "index": "foo" + } + +--- +"Test binary_soft_classification given evaluation with emtpy metrics": + - do: + catch: /\[binary_soft_classification\] must have one or more metrics/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + } + } + } + } + +--- +"Test 
binary_soft_classification given missing actual_field": + - do: + catch: /No documents found containing both \[missing, outlier_score\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "missing", + "predicted_probability_field": "outlier_score" + } + } + } + +--- +"Test binary_soft_classification given missing predicted_probability_field": + - do: + catch: /No documents found containing both \[is_outlier, missing\] fields/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "missing" + } + } + } + +--- +"Test binary_soft_classification given precision with threshold less than zero": + - do: + catch: /\[precision.at\] values must be in \[0.0, 1.0\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "precision": { "at": [ 0.25, -0.1 ]} + } + } + } + } + +--- +"Test binary_soft_classification given recall with threshold less than zero": + - do: + catch: /\[recall.at\] values must be in \[0.0, 1.0\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "recall": { "at": [ 0.25, -0.1 ]} + } + } + } + } + +--- +"Test binary_soft_classification given confusion_matrix with threshold less than zero": + - do: + catch: /\[confusion_matrix.at\] values must be in \[0.0, 1.0\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "confusion_matrix": { "at": [ 0.25, -0.1 ]} + } + } + } + } + +--- 
+"Test binary_soft_classification given precision with empty thresholds": + - do: + catch: /\[precision.at\] must have at least one value/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "precision": { "at": []} + } + } + } + } + +--- +"Test binary_soft_classification given recall with empty thresholds": + - do: + catch: /\[recall.at\] must have at least one value/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "recall": { "at": []} + } + } + } + } + +--- +"Test binary_soft_classification given confusion_matrix with empty thresholds": + - do: + catch: /\[confusion_matrix.at\] must have at least one value/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "is_outlier", + "predicted_probability_field": "outlier_score", + "metrics": { + "confusion_matrix": { "at": []} + } + } + } + } From 7fcd7710cfd8968c60315b5ee0e9927a095dd873 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 4 Apr 2019 12:26:30 +0300 Subject: [PATCH 29/67] [FEATURE][ML] Add source and dest as objects to analytics config (#40698) --- .../dataframe/DataFrameAnalyticsConfig.java | 111 ++------ .../ml/dataframe/DataFrameAnalyticsDest.java | 78 ++++++ .../dataframe/DataFrameAnalyticsSource.java | 144 ++++++++++ .../persistence/ElasticsearchMappings.java | 17 +- .../ml/job/results/ReservedFieldNames.java | 5 + .../DataFrameAnalyticsConfigTests.java | 38 +-- .../DataFrameAnalyticsDestTests.java | 34 +++ .../DataFrameAnalyticsSourceTests.java | 64 +++++ .../ml/qa/ml-with-security/build.gradle | 4 + .../integration/RunDataFrameAnalyticsIT.java | 12 +- 
.../TransportPutDataFrameAnalyticsAction.java | 10 +- .../dataframe/DataFrameAnalyticsManager.java | 13 +- .../DataFrameDataExtractorFactory.java | 10 +- .../process/AnalyticsProcessManager.java | 2 +- .../DataFrameAnalyticsManagerIT.java | 24 +- .../test/ml/data_frame_analytics_crud.yml | 256 ++++++++++++++---- .../test/ml/start_data_frame_analytics.yml | 16 +- 17 files changed, 625 insertions(+), 213 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index dea7689fdc2bc..54f0fb646ba81 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -5,9 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -15,17 +12,13 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; -import 
org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; -import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; -import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; import java.util.ArrayList; @@ -40,7 +33,6 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { - private static final Logger logger = LogManager.getLogger(DataFrameAnalyticsConfig.class); public static final String TYPE = "data_frame_analytics_config"; public static final ByteSizeValue DEFAULT_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.GB); @@ -51,7 +43,6 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { public static final ParseField DEST = new ParseField("dest"); public static final ParseField ANALYSES = new ParseField("analyses"); public static final ParseField CONFIG_TYPE = new ParseField("config_type"); - public static final ParseField QUERY = new ParseField("query"); public static final ParseField ANALYSES_FIELDS = new ParseField("analyses_fields"); public static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); public static final ParseField HEADERS = new ParseField("headers"); @@ -64,12 +55,9 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel parser.declareString((c, s) -> {}, CONFIG_TYPE); parser.declareString(Builder::setId, ID); - parser.declareString(Builder::setSource, SOURCE); - parser.declareString(Builder::setDest, DEST); + 
parser.declareObject(Builder::setSource, DataFrameAnalyticsSource.createParser(ignoreUnknownFields), SOURCE); + parser.declareObject(Builder::setDest, DataFrameAnalyticsDest.createParser(ignoreUnknownFields), DEST); parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES); - parser.declareObject(Builder::setQueryProvider, - (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT), - QUERY); parser.declareField(Builder::setAnalysesFields, (p, c) -> FetchSourceContext.fromXContent(p), ANALYSES_FIELDS, @@ -85,10 +73,9 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel } private final String id; - private final String source; - private final String dest; + private final DataFrameAnalyticsSource source; + private final DataFrameAnalyticsDest dest; private final List analyses; - private final QueryProvider queryProvider; private final FetchSourceContext analysesFields; /** * This may be null up to the point of persistence, as the relationship with xpack.ml.max_model_memory_limit @@ -101,8 +88,8 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final ByteSizeValue modelMemoryLimit; private final Map headers; - public DataFrameAnalyticsConfig(String id, String source, String dest, List analyses, - QueryProvider queryProvider, Map headers, ByteSizeValue modelMemoryLimit, + public DataFrameAnalyticsConfig(String id, DataFrameAnalyticsSource source, DataFrameAnalyticsDest dest, + List analyses, Map headers, ByteSizeValue modelMemoryLimit, FetchSourceContext analysesFields) { this.id = ExceptionsHelper.requireNonNull(id, ID); this.source = ExceptionsHelper.requireNonNull(source, SOURCE); @@ -115,7 +102,6 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List 1) { throw new UnsupportedOperationException("Does not yet support multiple analyses"); } - this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY); 
this.analysesFields = analysesFields; this.modelMemoryLimit = modelMemoryLimit; this.headers = Collections.unmodifiableMap(headers); @@ -123,10 +109,9 @@ public DataFrameAnalyticsConfig(String id, String source, String dest, List getAnalyses() { return analyses; } - /** - * Get the fully parsed query from the semi-parsed stored {@code Map} - * - * @return Fully parsed query - */ - public QueryBuilder getParsedQuery() { - Exception exception = queryProvider.getParsingException(); - if (exception != null) { - if (exception instanceof RuntimeException) { - throw (RuntimeException) exception; - } else { - throw new ElasticsearchException(queryProvider.getParsingException()); - } - } - return queryProvider.getParsedQuery(); - } - - Exception getQueryParsingException() { - return queryProvider == null ? null : queryProvider.getParsingException(); - } - - /** - * Calls the parser and returns any gathered deprecations - * - * @param namedXContentRegistry XContent registry to transform the lazily parsed query - * @return The deprecations from parsing the query - */ - public List getQueryDeprecations(NamedXContentRegistry namedXContentRegistry) { - List deprecations = new ArrayList<>(); - try { - XContentObjectTransformer.queryBuilderTransformer(namedXContentRegistry).fromMap(queryProvider.getQuery(), - deprecations); - } catch (Exception exception) { - // Certain thrown exceptions wrap up the real Illegal argument making it hard to determine cause for the user - if (exception.getCause() instanceof IllegalArgumentException) { - exception = (Exception) exception.getCause(); - } - throw ExceptionsHelper.badRequestException(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, exception); - } - return deprecations; - } - - public Map getQuery() { - return queryProvider.getQuery(); - } - public FetchSourceContext getAnalysesFields() { return analysesFields; } @@ -216,7 +155,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if 
(params.paramAsBoolean(ToXContentParams.INCLUDE_TYPE, false)) { builder.field(CONFIG_TYPE.getPreferredName(), TYPE); } - builder.field(QUERY.getPreferredName(), queryProvider.getQuery()); if (analysesFields != null) { builder.field(ANALYSES_FIELDS.getPreferredName(), analysesFields); } @@ -231,10 +169,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(id); - out.writeString(source); - out.writeString(dest); + source.writeTo(out); + dest.writeTo(out); out.writeList(analyses); - queryProvider.writeTo(out); out.writeOptionalWriteable(analysesFields); out.writeOptionalWriteable(modelMemoryLimit); out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); @@ -250,7 +187,6 @@ public boolean equals(Object o) { && Objects.equals(source, other.source) && Objects.equals(dest, other.dest) && Objects.equals(analyses, other.analyses) - && Objects.equals(queryProvider, other.queryProvider) && Objects.equals(headers, other.headers) && Objects.equals(getModelMemoryLimit(), other.getModelMemoryLimit()) && Objects.equals(analysesFields, other.analysesFields); @@ -258,7 +194,7 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(id, source, dest, analyses, queryProvider, headers, getModelMemoryLimit(), analysesFields); + return Objects.hash(id, source, dest, analyses, headers, getModelMemoryLimit(), analysesFields); } public static String documentId(String id) { @@ -268,10 +204,9 @@ public static String documentId(String id) { public static class Builder { private String id; - private String source; - private String dest; + private DataFrameAnalyticsSource source; + private DataFrameAnalyticsDest dest; private List analyses; - private QueryProvider queryProvider = QueryProvider.defaultQuery(); private FetchSourceContext analysesFields; private ByteSizeValue modelMemoryLimit; private ByteSizeValue 
maxModelMemoryLimit; @@ -293,10 +228,9 @@ public Builder(DataFrameAnalyticsConfig config) { public Builder(DataFrameAnalyticsConfig config, ByteSizeValue maxModelMemoryLimit) { this.id = config.id; - this.source = config.source; - this.dest = config.dest; + this.source = new DataFrameAnalyticsSource(config.source); + this.dest = new DataFrameAnalyticsDest(config.dest); this.analyses = new ArrayList<>(config.analyses); - this.queryProvider = new QueryProvider(config.queryProvider); this.headers = new HashMap<>(config.headers); this.modelMemoryLimit = config.modelMemoryLimit; this.maxModelMemoryLimit = maxModelMemoryLimit; @@ -314,12 +248,12 @@ public Builder setId(String id) { return this; } - public Builder setSource(String source) { + public Builder setSource(DataFrameAnalyticsSource source) { this.source = ExceptionsHelper.requireNonNull(source, SOURCE); return this; } - public Builder setDest(String dest) { + public Builder setDest(DataFrameAnalyticsDest dest) { this.dest = ExceptionsHelper.requireNonNull(dest, DEST); return this; } @@ -329,11 +263,6 @@ public Builder setAnalyses(List analyses) { return this; } - public Builder setQueryProvider(QueryProvider queryProvider) { - this.queryProvider = ExceptionsHelper.requireNonNull(queryProvider, QUERY.getPreferredName()); - return this; - } - public Builder setAnalysesFields(FetchSourceContext fields) { this.analysesFields = fields; return this; @@ -371,7 +300,7 @@ private void applyMaxModelMemoryLimit() { public DataFrameAnalyticsConfig build() { applyMaxModelMemoryLimit(); - return new DataFrameAnalyticsConfig(id, source, dest, analyses, queryProvider, headers, modelMemoryLimit, analysesFields); + return new DataFrameAnalyticsConfig(id, source, dest, analyses, headers, modelMemoryLimit, analysesFields); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java new file mode 100644 index 0000000000000..3f3c2636ed3c2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DataFrameAnalyticsDest implements Writeable, ToXContentObject { + + public static final ParseField INDEX = new ParseField("index"); + + public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("data_frame_analytics_dest", + ignoreUnknownFields, a -> new DataFrameAnalyticsDest((String) a[0])); + parser.declareString(ConstructingObjectParser.constructorArg(), INDEX); + return parser; + } + + private final String index; + + public DataFrameAnalyticsDest(String index) { + this.index = ExceptionsHelper.requireNonNull(index, INDEX); + if (index.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must be non-empty", INDEX); + } + } + + public DataFrameAnalyticsDest(StreamInput in) throws IOException { + index = in.readString(); + } + + public 
DataFrameAnalyticsDest(DataFrameAnalyticsDest other) { + this.index = other.index; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(index); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsDest other = (DataFrameAnalyticsDest) o; + return Objects.equals(index, other.index); + } + + @Override + public int hashCode() { + return Objects.hash(index); + } + + public String getIndex() { + return index; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java new file mode 100644 index 0000000000000..a57de375f3989 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; +import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class DataFrameAnalyticsSource implements Writeable, ToXContentObject { + + public static final ParseField INDEX = new ParseField("index"); + public static final ParseField QUERY = new ParseField("query"); + + public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("data_frame_analytics_source", + ignoreUnknownFields, a -> new DataFrameAnalyticsSource((String) a[0], (QueryProvider) a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), INDEX); + parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT), QUERY); + return parser; + } + + private final String index; + private final QueryProvider queryProvider; + + public DataFrameAnalyticsSource(String index, @Nullable QueryProvider 
queryProvider) { + this.index = ExceptionsHelper.requireNonNull(index, INDEX); + if (index.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must be non-empty", INDEX); + } + this.queryProvider = queryProvider == null ? QueryProvider.defaultQuery() : queryProvider; + } + + public DataFrameAnalyticsSource(StreamInput in) throws IOException { + index = in.readString(); + queryProvider = QueryProvider.fromStream(in); + } + + public DataFrameAnalyticsSource(DataFrameAnalyticsSource other) { + this.index = other.index; + this.queryProvider = new QueryProvider(other.queryProvider); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(index); + queryProvider.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + builder.field(QUERY.getPreferredName(), queryProvider.getQuery()); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsSource other = (DataFrameAnalyticsSource) o; + return Objects.equals(index, other.index) + && Objects.equals(queryProvider, other.queryProvider); + } + + @Override + public int hashCode() { + return Objects.hash(index, queryProvider); + } + + public String getIndex() { + return index; + } + + /** + * Get the fully parsed query from the semi-parsed stored {@code Map} + * + * @return Fully parsed query + */ + public QueryBuilder getParsedQuery() { + Exception exception = queryProvider.getParsingException(); + if (exception != null) { + if (exception instanceof RuntimeException) { + throw (RuntimeException) exception; + } else { + throw new ElasticsearchException(queryProvider.getParsingException()); + } + } + return queryProvider.getParsedQuery(); + } + + Exception getQueryParsingException() 
{ + return queryProvider.getParsingException(); + } + + /** + * Calls the parser and returns any gathered deprecations + * + * @param namedXContentRegistry XContent registry to transform the lazily parsed query + * @return The deprecations from parsing the query + */ + public List getQueryDeprecations(NamedXContentRegistry namedXContentRegistry) { + List deprecations = new ArrayList<>(); + try { + XContentObjectTransformer.queryBuilderTransformer(namedXContentRegistry).fromMap(queryProvider.getQuery(), + deprecations); + } catch (Exception exception) { + // Certain thrown exceptions wrap up the real Illegal argument making it hard to determine cause for the user + if (exception.getCause() instanceof IllegalArgumentException) { + exception = (Exception) exception.getCause(); + } + throw ExceptionsHelper.badRequestException(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, exception); + } + return deprecations; + } + + public Map getQuery() { + return queryProvider.getQuery(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 072b7fac9d455..75a94e8ef4d7d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -27,6 +27,8 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import 
org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -392,10 +394,21 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .field(TYPE, KEYWORD) .endObject() .startObject(DataFrameAnalyticsConfig.SOURCE.getPreferredName()) - .field(TYPE, KEYWORD) + .startObject(PROPERTIES) + .startObject(DataFrameAnalyticsSource.INDEX.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsSource.QUERY.getPreferredName()) + .field(ENABLED, false) + .endObject() + .endObject() .endObject() .startObject(DataFrameAnalyticsConfig.DEST.getPreferredName()) - .field(TYPE, KEYWORD) + .startObject(PROPERTIES) + .startObject(DataFrameAnalyticsDest.INDEX.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .endObject() .endObject() .startObject(DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName()) .field(ENABLED, false) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 378f5da401d2b..6786d5eade59f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -9,6 +9,8 @@ import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import 
org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -262,6 +264,9 @@ public final class ReservedFieldNames { DataFrameAnalyticsConfig.DEST.getPreferredName(), DataFrameAnalyticsConfig.ANALYSES.getPreferredName(), DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName(), + DataFrameAnalyticsDest.INDEX.getPreferredName(), + DataFrameAnalyticsSource.INDEX.getPreferredName(), + DataFrameAnalyticsSource.QUERY.getPreferredName(), "outlier_detection", "method", "number_neighbours", diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 100f960200c39..d45043714da98 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -24,15 +24,12 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.query.MatchAllQueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.AbstractSerializingTestCase; -import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -79,23 +76,14 @@ public static DataFrameAnalyticsConfig createRandom(String id) { } public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) { - String source = randomAlphaOfLength(10); - String dest = randomAlphaOfLength(10); + DataFrameAnalyticsSource source = 
DataFrameAnalyticsSourceTests.createRandom(); + DataFrameAnalyticsDest dest = DataFrameAnalyticsDestTests.createRandom(); List analyses = Collections.singletonList(DataFrameAnalysisConfigTests.randomConfig()); DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder() .setId(id) .setAnalyses(analyses) .setSource(source) .setDest(dest); - if (randomBoolean()) { - try { - builder.setQueryProvider( - QueryProvider.fromParsedQuery(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10)))); - } catch (IOException e) { - // Should never happen - throw new UncheckedIOException(e); - } - } if (randomBoolean()) { builder.setAnalysesFields(new FetchSourceContext(true, generateRandomStringArray(10, 10, false, false), @@ -114,20 +102,18 @@ public static String randomValidId() { private static final String ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS = "{\n" + " \"id\": \"old-data-frame\",\n" + - " \"source\": \"my-index\",\n" + - " \"dest\": \"dest-index\",\n" + - " \"analyses\": {\"outlier_detection\": {\"number_neighbours\": 10}},\n" + //query:match:type stopped being supported in 6.x - " \"query\": {\"match\" : {\"query\":\"fieldName\", \"type\": \"phrase\"}}\n" + + " \"source\": {\"index\":\"my-index\", \"query\": {\"match\" : {\"query\":\"fieldName\", \"type\": \"phrase\"}}},\n" + + " \"dest\": {\"index\":\"dest-index\"},\n" + + " \"analyses\": {\"outlier_detection\": {\"number_neighbours\": 10}}\n" + "}"; private static final String MODERN_QUERY_DATA_FRAME_ANALYTICS = "{\n" + " \"id\": \"data-frame\",\n" + - " \"source\": \"my-index\",\n" + - " \"dest\": \"dest-index\",\n" + - " \"analyses\": {\"outlier_detection\": {\"number_neighbours\": 10}},\n" + // match_all if parsed, adds default values in the options - " \"query\": {\"match_all\" : {}}\n" + + " \"source\": {\"index\":\"my-index\", \"query\": {\"match_all\" : {}}},\n" + + " \"dest\": {\"index\":\"dest-index\"},\n" + + " \"analyses\": {\"outlier_detection\": 
{\"number_neighbours\": 10}}\n" + "}"; public void testQueryConfigStoresUserInputOnly() throws IOException { @@ -137,7 +123,7 @@ public void testQueryConfigStoresUserInputOnly() throws IOException { MODERN_QUERY_DATA_FRAME_ANALYTICS)) { DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); - assertThat(config.getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); + assertThat(config.getSource().getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); } try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) @@ -146,7 +132,7 @@ public void testQueryConfigStoresUserInputOnly() throws IOException { MODERN_QUERY_DATA_FRAME_ANALYTICS)) { DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build(); - assertThat(config.getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); + assertThat(config.getSource().getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); } } @@ -157,7 +143,7 @@ public void testPastQueryConfigParse() throws IOException { ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) { DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); - ElasticsearchException e = expectThrows(ElasticsearchException.class, config::getParsedQuery); + ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> config.getSource().getParsedQuery()); assertEquals("[match] query doesn't support multiple fields, found [query] and [type]", e.getMessage()); } @@ -168,7 +154,7 @@ public void testPastQueryConfigParse() throws IOException { XContentParseException e = expectThrows(XContentParseException.class, () -> DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build()); - assertEquals("[6:64] [data_frame_analytics_config] failed to parse field 
[query]", e.getMessage()); + assertThat(e.getMessage(), containsString("[data_frame_analytics_config] failed to parse field [source]")); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java new file mode 100644 index 0000000000000..0e34be10a21ff --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase { + + @Override + protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsDest.createParser(false).apply(parser, null); + } + + @Override + protected DataFrameAnalyticsDest createTestInstance() { + return createRandom(); + } + + public static DataFrameAnalyticsDest createRandom() { + return new DataFrameAnalyticsDest(randomAlphaOfLength(10)); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataFrameAnalyticsDest::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java new file mode 100644 index 0000000000000..7783354d425a9 --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; + +public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } + + @Override + protected DataFrameAnalyticsSource doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsSource.createParser(false).apply(parser, null); + } + + @Override + protected DataFrameAnalyticsSource createTestInstance() { + return createRandom(); + } + + public static DataFrameAnalyticsSource createRandom() { + String index = 
randomAlphaOfLength(10); + QueryProvider queryProvider = null; + if (randomBoolean()) { + try { + queryProvider = QueryProvider.fromParsedQuery(QueryBuilders.termQuery(randomAlphaOfLength(10), randomAlphaOfLength(10))); + } catch (IOException e) { + // Should never happen + throw new UncheckedIOException(e); + } + } + return new DataFrameAnalyticsSource(index, queryProvider); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataFrameAnalyticsSource::new; + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index b2dc5bc54f76e..66e95ad5b4681 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -41,7 +41,11 @@ integTestRunner { 'ml/data_frame_analytics_crud/Test put config with unknown top level field', 'ml/data_frame_analytics_crud/Test put config with unknown field in outlier detection analysis', 'ml/data_frame_analytics_crud/Test put config given missing source', + 'ml/data_frame_analytics_crud/Test put config given source with empty index', + 'ml/data_frame_analytics_crud/Test put config given source without index', 'ml/data_frame_analytics_crud/Test put config given missing dest', + 'ml/data_frame_analytics_crud/Test put config given dest with empty index', + 'ml/data_frame_analytics_crud/Test put config given dest without index', 'ml/data_frame_analytics_crud/Test put config given missing analyses', 'ml/data_frame_analytics_crud/Test put config given empty analyses', 'ml/data_frame_analytics_crud/Test put config given two analyses', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 45395403a3a1a..7abb831462f17 100644 --- 
a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -16,6 +16,8 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.junit.After; @@ -76,7 +78,7 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { double scoreOfOutlier = 0.0; double scoreOfNonOutlier = -1.0; for (SearchHit hit : sourceData.getHits()) { - GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest()).setId(hit.getId()).get(); + GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); assertThat(destDocGetResponse.isExists(), is(true)); Map sourceDoc = hit.getSourceAsMap(); Map destDoc = destDocGetResponse.getSource(); @@ -132,11 +134,11 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { waitUntilAnalyticsIsStopped(id); // Check we've got all docs - SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); // Check they all have an outlier_score - searchResponse = client().prepareSearch(config.getDest()) + searchResponse = client().prepareSearch(config.getDest().getIndex()) .setTrackTotalHits(true) 
.setQuery(QueryBuilders.existsQuery("outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); @@ -144,8 +146,8 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); - configBuilder.setSource(sourceIndex); - configBuilder.setDest(sourceIndex + "-results"); + configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); + configBuilder.setDest(new DataFrameAnalyticsDest(sourceIndex + "-results")); Map analysisConfig = new HashMap<>(); analysisConfig.put("outlier_detection", Collections.emptyMap()); configBuilder.setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig(analysisConfig))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java index 3be814a72fc2d..addabdb625553 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java @@ -88,11 +88,11 @@ protected void doExecute(Task task, PutDataFrameAnalyticsAction.Request request, if (licenseState.isAuthAllowed()) { final String username = securityContext.getUser().principal(); RoleDescriptor.IndicesPrivileges sourceIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() - .indices(memoryCappedConfig.getSource()) + .indices(memoryCappedConfig.getSource().getIndex()) .privileges("read") .build(); RoleDescriptor.IndicesPrivileges destIndexPrivileges = RoleDescriptor.IndicesPrivileges.builder() - .indices(memoryCappedConfig.getDest()) + 
.indices(memoryCappedConfig.getDest().getIndex()) .privileges("read", "index", "create_index") .build(); @@ -147,12 +147,6 @@ private void validateConfig(DataFrameAnalyticsConfig config) { throw ExceptionsHelper.badRequestException("id [{}] is too long; must not contain more than {} characters", config.getId(), MlStrings.ID_LENGTH_LIMIT); } - if (config.getSource().isEmpty()) { - throw ExceptionsHelper.badRequestException("[{}] must be non-empty", DataFrameAnalyticsConfig.SOURCE); - } - if (config.getDest().isEmpty()) { - throw ExceptionsHelper.badRequestException("[{}] must be non-empty", DataFrameAnalyticsConfig.DEST); - } DataFrameAnalysesUtils.readAnalyses(config.getAnalyses()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 3f2f568c8d419..4319bcada0bbd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -109,7 +109,7 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current ClientHelper.executeAsyncWithOrigin(client, ML_ORIGIN, DeleteIndexAction.INSTANCE, - new DeleteIndexRequest(config.getDest()), + new DeleteIndexRequest(config.getDest().getIndex()), ActionListener.wrap( r-> reindexingStateListener.onResponse(config), e -> { @@ -149,7 +149,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF ClientHelper.executeAsyncWithOrigin(client, ClientHelper.ML_ORIGIN, RefreshAction.INSTANCE, - new RefreshRequest(config.getDest()), + new RefreshRequest(config.getDest().getIndex()), refreshListener), task::markAsFailed ); @@ -158,10 +158,9 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF ActionListener copyIndexCreatedListener = 
ActionListener.wrap( createIndexResponse -> { ReindexRequest reindexRequest = new ReindexRequest(); - reindexRequest.setSourceIndices(config.getSource()); - // we default to match_all - reindexRequest.setSourceQuery(config.getParsedQuery()); - reindexRequest.setDestIndex(config.getDest()); + reindexRequest.setSourceIndices(config.getSource().getIndex()); + reindexRequest.setSourceQuery(config.getSource().getParsedQuery()); + reindexRequest.setDestIndex(config.getDest().getIndex()); reindexRequest.setScript(new Script("ctx._source." + DataFrameAnalyticsFields.ID + " = ctx._id")); final ThreadContext threadContext = client.threadPool().getThreadContext(); @@ -175,7 +174,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF reindexCompletedListener::onFailure ); - createDestinationIndex(config.getSource(), config.getDest(), config.getHeaders(), copyIndexCreatedListener); + createDestinationIndex(config.getSource().getIndex(), config.getDest().getIndex(), config.getHeaders(), copyIndexCreatedListener); } private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index ab8fb96f74e13..c827e0b48edd5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -102,10 +102,10 @@ public static void create(Client client, DataFrameAnalyticsConfig config, ActionListener listener) { Set resultFields = resolveResultsFields(config); - validateIndexAndExtractFields(client, config.getHeaders(), config.getDest(), config.getAnalysesFields(), resultFields, + validateIndexAndExtractFields(client, 
config.getHeaders(), config.getDest().getIndex(), config.getAnalysesFields(), resultFields, ActionListener.wrap( - extractedFields -> listener.onResponse( - new DataFrameDataExtractorFactory(client, config.getId(), config.getDest(), extractedFields, config.getHeaders())), + extractedFields -> listener.onResponse(new DataFrameDataExtractorFactory( + client, config.getId(), config.getDest().getIndex(), extractedFields, config.getHeaders())), listener::onFailure )); } @@ -121,10 +121,10 @@ public static void validateConfigAndSourceIndex(Client client, DataFrameAnalyticsConfig config, ActionListener listener) { Set resultFields = resolveResultsFields(config); - validateIndexAndExtractFields(client, config.getHeaders(), config.getSource(), config.getAnalysesFields(), resultFields, + validateIndexAndExtractFields(client, config.getHeaders(), config.getSource().getIndex(), config.getAnalysesFields(), resultFields, ActionListener.wrap( fields -> { - config.getParsedQuery(); // validate query is acceptable + config.getSource().getParsedQuery(); // validate query is acceptable listener.onResponse(config); }, listener::onFailure diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 5849b597de688..f4c97bdd909f0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -162,7 +162,7 @@ public Integer getProgressPercent(long allocationId) { private void refreshDest(DataFrameAnalyticsConfig config) { ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, - () -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest())).actionGet()); + () -> client.execute(RefreshAction.INSTANCE, new 
RefreshRequest(config.getDest().getIndex())).actionGet()); } static class ProcessContext { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java index b7e23902ce2b2..5c782947a310c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java @@ -27,6 +27,8 @@ import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.ml.LocalStateMachineLearning; @@ -58,16 +60,16 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import static org.hamcrest.Matchers.equalTo; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class DataFrameAnalyticsManagerIT extends BaseMlIntegTestCase { @@ -110,7 +112,7 @@ public void 
testTaskContinuationFromReindexState() throws Exception { assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); // Check we've got all docs - SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); for(SearchHit hit : searchResponse.getHits().getHits()) { Map src = hit.getSourceAsMap(); @@ -132,7 +134,7 @@ public void testTaskContinuationFromReindexStateWithPreviousResultsIndex() throw putDataFrameAnalyticsConfig(config); // Create the "results" index, as if we ran reindex already in the process, but did not transition from the state properly - createAnalysesResultsIndex(config.getDest(), false); + createAnalysesResultsIndex(config.getDest().getIndex(), false); List results = buildExpectedResults(sourceIndex); DataFrameAnalyticsManager manager = createManager(results); @@ -144,7 +146,7 @@ public void testTaskContinuationFromReindexStateWithPreviousResultsIndex() throw assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); // Check we've got all docs - SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); for(SearchHit hit : searchResponse.getHits().getHits()) { Map src = hit.getSourceAsMap(); @@ -165,7 +167,7 @@ public void testTaskContinuationFromAnalyzeState() throws Exception { DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); putDataFrameAnalyticsConfig(config); // Create the "results" index to simulate running reindex already and having partially ran analysis - createAnalysesResultsIndex(config.getDest(), 
true); + createAnalysesResultsIndex(config.getDest().getIndex(), true); List results = buildExpectedResults(sourceIndex); DataFrameAnalyticsManager manager = createManager(results); @@ -177,7 +179,7 @@ public void testTaskContinuationFromAnalyzeState() throws Exception { assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); // Check we've got all docs - SearchResponse searchResponse = client().prepareSearch(config.getDest()).setTrackTotalHits(true).get(); + SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); for(SearchHit hit : searchResponse.getHits().getHits()) { Map src = hit.getSourceAsMap(); @@ -201,8 +203,8 @@ private synchronized boolean isCompleted() { private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); - configBuilder.setSource(sourceIndex); - configBuilder.setDest(sourceIndex + "-results"); + configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); + configBuilder.setDest(new DataFrameAnalyticsDest(sourceIndex + "-results")); Map analysisConfig = new HashMap<>(); analysisConfig.put("outlier_detection", Collections.emptyMap()); configBuilder.setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig(analysisConfig))); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index ed92f1423db19..38feb3fc58792 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -33,17 +33,21 @@ id: "simple-outlier-detection-with-query" body: > { - "source": "index-source", - "dest": 
"index-dest", + "source": { + "index": "index-source", + "query": {"term" : { "user" : "Kimchy" }} + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}], - "query": {"term" : { "user" : "Kimchy" }}, "analyses_fields": [ "obj1.*", "obj2.*" ] } - match: { id: "simple-outlier-detection-with-query" } - - match: { source: "index-source" } - - match: { dest: "index-dest" } + - match: { source.index: "index-source" } + - match: { source.query: {"term" : { "user" : "Kimchy"} } } + - match: { dest.index: "index-dest" } - match: { analyses: [{"outlier_detection":{}}] } - - match: { query: {"term" : { "user" : "Kimchy"} } } - match: { analyses_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } --- @@ -54,10 +58,14 @@ id: "data_frame_with_header" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source", + "query": {"term" : { "user" : "Kimchy" }} + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}], - "query": {"term" : { "user" : "Kimchy" }}, "headers":{ "a_security_header" : "secret" } } @@ -69,15 +77,19 @@ id: "simple-outlier-detection" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}] } - match: { id: "simple-outlier-detection" } - - match: { source: "index-source" } - - match: { dest: "index-dest" } + - match: { source.index: "index-source" } + - match: { source.query: {"match_all" : {} } } + - match: { dest.index: "index-dest" } - match: { analyses: [{"outlier_detection":{}}] } - - match: { query: {"match_all" : {} } } --- "Test put config with inconsistent body/param ids": @@ -89,8 +101,12 @@ body: > { "id": "body_id", - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}] } @@ -103,8 +119,12 @@ 
id: "this id contains spaces" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}] } @@ -117,8 +137,12 @@ id: "unknown_field" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}], "unknown_field": 42 } @@ -132,8 +156,12 @@ id: "unknown_field" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{"unknown_field": 42}}] } @@ -146,7 +174,44 @@ id: "simple-outlier-detection" body: > { - "dest": "index-dest", + "dest": { + "index": "index-dest" + }, + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config given source with empty index": + + - do: + catch: /\[index\] must be non-empty/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": { + "index": "" + }, + "dest": { + "index": "index-dest" + }, + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config given source without index": + + - do: + catch: /Required \[index\]/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": { + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}] } @@ -159,7 +224,44 @@ id: "simple-outlier-detection" body: > { - "source": "index-source", + "source": { + "index": "index-source" + }, + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config given dest with empty index": + + - do: + catch: /\[index\] must be non-empty/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "" + }, + "analyses": [{"outlier_detection":{}}] + } + +--- +"Test put config given dest without 
index": + + - do: + catch: /Required \[index\]/ + ml.put_data_frame_analytics: + id: "simple-outlier-detection" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + }, "analyses": [{"outlier_detection":{}}] } @@ -172,8 +274,12 @@ id: "simple-outlier-detection" body: > { - "source": "index-source", - "dest": "index-dest" + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + } } --- @@ -185,8 +291,12 @@ id: "simple-outlier-detection" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [] } @@ -199,8 +309,12 @@ id: "simple-outlier-detection" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}, {"outlier_detection":{}}] } @@ -212,8 +326,12 @@ id: "foo-1" body: > { - "source": "index-foo-1_source", - "dest": "index-foo-1_dest", + "source": { + "index": "index-foo-1_source" + }, + "dest": { + "index": "index-foo-1_dest" + }, "analyses": [{"outlier_detection":{}}] } @@ -222,8 +340,12 @@ id: "foo-2" body: > { - "source": "index-foo-2_source", - "dest": "index-foo-2_dest", + "source": { + "index": "index-foo-2_source" + }, + "dest": { + "index": "index-foo-2_dest" + }, "analyses": [{"outlier_detection":{}}] } - match: { id: "foo-2" } @@ -233,8 +355,12 @@ id: "bar" body: > { - "source": "index-bar_source", - "dest": "index-bar_dest", + "source": { + "index": "index-bar_source" + }, + "dest": { + "index": "index-bar_dest" + }, "analyses": [{"outlier_detection":{}}] } - match: { id: "bar" } @@ -297,8 +423,12 @@ id: "foo-1" body: > { - "source": "index-foo-1_source", - "dest": "index-foo-1_dest", + "source": { + "index": "index-foo-1_source" + }, + "dest": { + "index": "index-foo-1_dest" + }, "analyses": [{"outlier_detection":{}}] } @@ -307,8 +437,12 @@ id: "foo-2" body: > { - "source": 
"index-foo-2_source", - "dest": "index-foo-2_dest", + "source": { + "index": "index-foo-2_source" + }, + "dest": { + "index": "index-foo-2_dest" + }, "analyses": [{"outlier_detection":{}}] } - match: { id: "foo-2" } @@ -318,8 +452,12 @@ id: "bar" body: > { - "source": "index-bar_source", - "dest": "index-bar_dest", + "source": { + "index": "index-bar_source" + }, + "dest": { + "index": "index-bar_dest" + }, "analyses": [{"outlier_detection":{}}] } - match: { id: "bar" } @@ -383,8 +521,12 @@ id: "foo" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}] } @@ -426,10 +568,14 @@ id: "simple-outlier-detection-with-query" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source", + "query": {"term" : { "user" : "Kimchy" }} + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}], - "query": {"term" : { "user" : "Kimchy" }}, "model_memory_limit": "8gb", "analyses_fields": [ "obj1.*", "obj2.*" ] } @@ -440,17 +586,21 @@ id: "simple-outlier-detection-with-query" body: > { - "source": "index-source", - "dest": "index-dest", + "source": { + "index": "index-source", + "query": {"term" : { "user" : "Kimchy" }} + }, + "dest": { + "index": "index-dest" + }, "analyses": [{"outlier_detection":{}}], - "query": {"term" : { "user" : "Kimchy" }}, "analyses_fields": [ "obj1.*", "obj2.*" ] } - match: { id: "simple-outlier-detection-with-query" } - - match: { source: "index-source" } - - match: { dest: "index-dest" } + - match: { source.index: "index-source" } + - match: { source.query: {"term" : { "user" : "Kimchy"} } } + - match: { dest.index: "index-dest" } - match: { analyses: [{"outlier_detection":{}}] } - - match: { query: {"term" : { "user" : "Kimchy"} } } - match: { analyses_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } - match: { model_memory_limit: "20mb" } 
diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml index 5042f80c2b7b8..8149565238730 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml @@ -13,8 +13,12 @@ id: "missing_index" body: > { - "source": "missing", - "dest": "missing-dest", + "source": { + "index": "missing" + }, + "dest": { + "index": "missing-dest" + }, "analyses": [{"outlier_detection":{}}] } @@ -35,8 +39,12 @@ id: "foo" body: > { - "source": "empty-index", - "dest": "empty-index-dest", + "source": { + "index": "empty-index" + }, + "dest": { + "index": "empty-index-dest" + }, "analyses": [{"outlier_detection":{}}] } From 6ec21624c6728aa36715567ce8059f0d1e46e91e Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 9 Apr 2019 15:22:52 +0300 Subject: [PATCH 30/67] [FEATURE][ML] Fix results joining when rows are skipped (#40906) --- .../process/AnalyticsResultProcessor.java | 10 +- .../process/DataFrameRowsJoiner.java | 151 +++++++++++------ .../AnalyticsResultProcessorTests.java | 3 + .../process/DataFrameRowsJoinerTests.java | 158 +++++++++++++++++- 4 files changed, 260 insertions(+), 62 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index b2721daf1e515..7e204e52e0c4a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -39,12 +39,12 @@ public void awaitForCompletion() { } public void process(AnalyticsProcess process) { - - try { + // TODO When java 9 
features can be used, we will not need the local variable here + try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { AnalyticsResult result = iterator.next(); - processResult(result); + processResult(result, resultsJoiner); } } catch (Exception e) { LOGGER.error("Error parsing data frame analytics output", e); @@ -54,10 +54,10 @@ public void process(AnalyticsProcess process) { } } - private void processResult(AnalyticsResult result) { + private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) { RowResults rowResults = result.getRowResults(); if (rowResults != null) { - dataFrameRowsJoiner.processRowResults(rowResults); + resultsJoiner.processRowResults(rowResults); } Integer progressPercent = result.getProgressPercent(); if (progressPercent != null) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java index a86645e4fc52d..ef943820374ea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoiner.java @@ -16,32 +16,39 @@ import org.elasticsearch.client.Client; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import java.io.IOException; -import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; import java.util.LinkedHashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; 
-class DataFrameRowsJoiner { +class DataFrameRowsJoiner implements AutoCloseable { private static final Logger LOGGER = LogManager.getLogger(DataFrameRowsJoiner.class); + private static final int RESULTS_BATCH_SIZE = 1000; + private final String analyticsId; private final Client client; private final DataFrameDataExtractor dataExtractor; - private List currentDataFrameRows; - private List currentResults; + private final Iterator dataFrameRowsIterator; + private LinkedList currentResults; private boolean failed; DataFrameRowsJoiner(String analyticsId, Client client, DataFrameDataExtractor dataExtractor) { this.analyticsId = Objects.requireNonNull(analyticsId); this.client = Objects.requireNonNull(client); this.dataExtractor = Objects.requireNonNull(dataExtractor); + this.dataFrameRowsIterator = new ResultMatchingDataFrameRows(); + this.currentResults = new LinkedList<>(); } void processRowResults(RowResults rowResults) { @@ -60,62 +67,24 @@ void processRowResults(RowResults rowResults) { } private void addResultAndJoinIfEndOfBatch(RowResults rowResults) { - if (currentDataFrameRows == null) { - Optional> nextBatch = getNextBatch(); - if (nextBatch.isPresent() == false) { - return; - } - currentDataFrameRows = nextBatch.get(); - currentResults = new ArrayList<>(currentDataFrameRows.size()); - } currentResults.add(rowResults); - if (currentResults.size() == currentDataFrameRows.size()) { + if (currentResults.size() == RESULTS_BATCH_SIZE) { joinCurrentResults(); - currentDataFrameRows = null; - } - } - - private Optional> getNextBatch() { - try { - return dataExtractor.next(); - } catch (IOException e) { - // TODO Implement recovery strategy or better error reporting - LOGGER.error("Error reading next batch of data frame rows", e); - return Optional.empty(); } } private void joinCurrentResults() { BulkRequest bulkRequest = new BulkRequest(); - for (int i = 0; i < currentDataFrameRows.size(); i++) { - DataFrameDataExtractor.Row row = currentDataFrameRows.get(i); - if 
(row.shouldSkip()) { - continue; - } - RowResults result = currentResults.get(i); + while (currentResults.isEmpty() == false) { + RowResults result = currentResults.pop(); + DataFrameDataExtractor.Row row = dataFrameRowsIterator.next(); checkChecksumsMatch(row, result); - - SearchHit hit = row.getHit(); - Map source = new LinkedHashMap(hit.getSourceAsMap()); - source.putAll(result.getResults()); - new IndexRequest(hit.getIndex()); - IndexRequest indexRequest = new IndexRequest(hit.getIndex()); - indexRequest.id(hit.getId()); - indexRequest.source(source); - indexRequest.opType(DocWriteRequest.OpType.INDEX); - bulkRequest.add(indexRequest); + bulkRequest.add(createIndexRequest(result, row.getHit())); } if (bulkRequest.numberOfActions() > 0) { - BulkResponse bulkResponse = - ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), - ClientHelper.ML_ORIGIN, - client, - () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()); - if (bulkResponse.hasFailures()) { - LOGGER.error("Failures while writing data frame"); - // TODO Better error handling - } + executeBulkRequest(bulkRequest); } + currentResults = new LinkedList<>(); } private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults result) { @@ -128,4 +97,88 @@ private void checkChecksumsMatch(DataFrameDataExtractor.Row row, RowResults resu // TODO Communicate this error to the user as effectively the analytics have failed (e.g. FAILED state, audit error, etc.) 
} } + + private IndexRequest createIndexRequest(RowResults result, SearchHit hit) { + Map source = new LinkedHashMap(hit.getSourceAsMap()); + source.putAll(result.getResults()); + IndexRequest indexRequest = new IndexRequest(hit.getIndex()); + indexRequest.id(hit.getId()); + indexRequest.source(source); + indexRequest.opType(DocWriteRequest.OpType.INDEX); + return indexRequest; + } + + private void executeBulkRequest(BulkRequest bulkRequest) { + BulkResponse bulkResponse = ClientHelper.executeWithHeaders(dataExtractor.getHeaders(), ClientHelper.ML_ORIGIN, client, + () -> client.execute(BulkAction.INSTANCE, bulkRequest).actionGet()); + if (bulkResponse.hasFailures()) { + LOGGER.error("Failures while writing data frame"); + // TODO Better error handling + } + } + + @Override + public void close() { + try { + joinCurrentResults(); + } catch (Exception e) { + LOGGER.error(new ParameterizedMessage("[{}] Failed to join results", analyticsId), e); + failed = true; + } finally { + try { + consumeDataExtractor(); + } catch (Exception e) { + LOGGER.error(new ParameterizedMessage("[{}] Failed to consume data extractor", analyticsId), e); + } + } + } + + private void consumeDataExtractor() throws IOException { + dataExtractor.cancel(); + while (dataExtractor.hasNext()) { + dataExtractor.next(); + } + } + + private class ResultMatchingDataFrameRows implements Iterator { + + private List currentDataFrameRows = Collections.emptyList(); + private int currentDataFrameRowsIndex; + + @Override + public boolean hasNext() { + return dataExtractor.hasNext() || currentDataFrameRowsIndex < currentDataFrameRows.size(); + } + + @Override + public DataFrameDataExtractor.Row next() { + DataFrameDataExtractor.Row row = null; + while ((row == null || row.shouldSkip()) && hasNext()) { + advanceToNextBatchIfNecessary(); + row = currentDataFrameRows.get(currentDataFrameRowsIndex++); + } + + if (row == null || row.shouldSkip()) { + throw ExceptionsHelper.serverError("No more data frame rows could 
be found while joining results"); + } + return row; + } + + private void advanceToNextBatchIfNecessary() { + if (currentDataFrameRowsIndex >= currentDataFrameRows.size()) { + currentDataFrameRows = getNextDataRowsBatch().orElse(Collections.emptyList()); + currentDataFrameRowsIndex = 0; + } + } + + private Optional> getNextDataRowsBatch() { + try { + return dataExtractor.next(); + } catch (IOException e) { + // TODO Implement recovery strategy or better error reporting + LOGGER.error("Error reading next batch of data frame rows", e); + return Optional.empty(); + } + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 716b96a615846..e3f4cf6ebc9f7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -16,6 +16,7 @@ import java.util.List; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -37,6 +38,7 @@ public void testProcess_GivenNoResults() { resultProcessor.process(process); resultProcessor.awaitForCompletion(); + verify(dataFrameRowsJoiner).close(); verifyNoMoreInteractions(dataFrameRowsJoiner); } @@ -47,6 +49,7 @@ public void testProcess_GivenEmptyResults() { resultProcessor.process(process); resultProcessor.awaitForCompletion(); + verify(dataFrameRowsJoiner).close(); Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java 
index a4795c0ad7cc6..fd2b396d62541 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/DataFrameRowsJoinerTests.java @@ -25,16 +25,20 @@ import org.mockito.ArgumentCaptor; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.IntStream; import static org.hamcrest.Matchers.equalTo; import static org.mockito.Matchers.same; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -44,14 +48,12 @@ public class DataFrameRowsJoinerTests extends ESTestCase { private Client client; private DataFrameDataExtractor dataExtractor; - private DataFrameRowsJoiner dataFrameRowsJoiner; private ArgumentCaptor bulkRequestCaptor = ArgumentCaptor.forClass(BulkRequest.class); @Before public void setUpMocks() { client = mock(Client.class); dataExtractor = mock(DataFrameDataExtractor.class); - dataFrameRowsJoiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor); } public void testProcess_GivenNoResults() { @@ -65,7 +67,7 @@ public void testProcess_GivenSingleRowAndResult() throws IOException { String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; String[] dataValues = {"42.0"}; DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); - givenSingleDataFrameBatch(Arrays.asList(row)); + givenDataFrameBatches(Arrays.asList(row)); Map resultFields = new HashMap<>(); resultFields.put("a", "1"); @@ -86,13 +88,37 @@ public void testProcess_GivenSingleRowAndResult() throws IOException { assertThat(indexedDocSource.get("b"), equalTo("2")); } + public void 
testProcess_GivenFullResultsBatch() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + List firstBatch = new ArrayList<>(1000); + IntStream.range(0, 1000).forEach(i -> firstBatch.add(newRow(newHit(dataDoc), dataValues, i))); + List secondBatch = new ArrayList<>(1); + secondBatch.add(newRow(newHit(dataDoc), dataValues, 1000)); + givenDataFrameBatches(firstBatch, secondBatch); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + List results = new ArrayList<>(1001); + IntStream.range(0, 1001).forEach(i -> results.add(new RowResults(i, resultFields))); + givenProcessResults(results); + + List capturedBulkRequests = bulkRequestCaptor.getAllValues(); + assertThat(capturedBulkRequests.size(), equalTo(2)); + assertThat(capturedBulkRequests.get(0).numberOfActions(), equalTo(1000)); + assertThat(capturedBulkRequests.get(1).numberOfActions(), equalTo(1)); + } + public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IOException { givenClientHasNoFailures(); String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; String[] dataValues = {"42.0"}; DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); - givenSingleDataFrameBatch(Arrays.asList(row)); + givenDataFrameBatches(Arrays.asList(row)); Map resultFields = new HashMap<>(); resultFields.put("a", "1"); @@ -103,13 +129,110 @@ public void testProcess_GivenSingleRowAndResultWithMismatchingIdHash() throws IO verifyNoMoreInteractions(client); } + public void testProcess_GivenSingleBatchWithSkippedRows() throws IOException { + givenClientHasNoFailures(); + + DataFrameDataExtractor.Row skippedRow = newRow(newHit("{}"), null, 1); + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row normalRow = newRow(newHit(dataDoc), dataValues, 2); + givenDataFrameBatches(Arrays.asList(skippedRow, 
normalRow)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + RowResults result = new RowResults(2, resultFields); + givenProcessResults(Arrays.asList(result)); + + List capturedBulkRequests = bulkRequestCaptor.getAllValues(); + assertThat(capturedBulkRequests.size(), equalTo(1)); + BulkRequest capturedBulkRequest = capturedBulkRequests.get(0); + assertThat(capturedBulkRequest.numberOfActions(), equalTo(1)); + IndexRequest indexRequest = (IndexRequest) capturedBulkRequest.requests().get(0); + Map indexedDocSource = indexRequest.sourceAsMap(); + assertThat(indexedDocSource.size(), equalTo(4)); + assertThat(indexedDocSource.get("f_1"), equalTo("foo")); + assertThat(indexedDocSource.get("f_2"), equalTo(42.0)); + assertThat(indexedDocSource.get("a"), equalTo("1")); + assertThat(indexedDocSource.get("b"), equalTo("2")); + } + + public void testProcess_GivenTwoBatchesWhereFirstEndsWithSkippedRow() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row normalRow1 = newRow(newHit(dataDoc), dataValues, 1); + DataFrameDataExtractor.Row normalRow2 = newRow(newHit(dataDoc), dataValues, 2); + DataFrameDataExtractor.Row skippedRow = newRow(newHit("{}"), null, 3); + DataFrameDataExtractor.Row normalRow3 = newRow(newHit(dataDoc), dataValues, 4); + givenDataFrameBatches(Arrays.asList(normalRow1, normalRow2, skippedRow), Arrays.asList(normalRow3)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + RowResults result1 = new RowResults(1, resultFields); + RowResults result2 = new RowResults(2, resultFields); + RowResults result3 = new RowResults(4, resultFields); + givenProcessResults(Arrays.asList(result1, result2, result3)); + + List capturedBulkRequests = bulkRequestCaptor.getAllValues(); + assertThat(capturedBulkRequests.size(), equalTo(1)); + BulkRequest 
capturedBulkRequest = capturedBulkRequests.get(0); + assertThat(capturedBulkRequest.numberOfActions(), equalTo(3)); + IndexRequest indexRequest = (IndexRequest) capturedBulkRequest.requests().get(0); + Map indexedDocSource = indexRequest.sourceAsMap(); + assertThat(indexedDocSource.size(), equalTo(4)); + assertThat(indexedDocSource.get("f_1"), equalTo("foo")); + assertThat(indexedDocSource.get("f_2"), equalTo(42.0)); + assertThat(indexedDocSource.get("a"), equalTo("1")); + assertThat(indexedDocSource.get("b"), equalTo("2")); + } + + public void testProcess_GivenMoreResultsThanRows() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row = newRow(newHit(dataDoc), dataValues, 1); + givenDataFrameBatches(Arrays.asList(row)); + + Map resultFields = new HashMap<>(); + resultFields.put("a", "1"); + resultFields.put("b", "2"); + RowResults result1 = new RowResults(1, resultFields); + RowResults result2 = new RowResults(2, resultFields); + givenProcessResults(Arrays.asList(result1, result2)); + + verifyNoMoreInteractions(client); + } + + public void testProcess_GivenNoResults_ShouldCancelAndConsumeExtractor() throws IOException { + givenClientHasNoFailures(); + + String dataDoc = "{\"f_1\": \"foo\", \"f_2\": 42.0}"; + String[] dataValues = {"42.0"}; + DataFrameDataExtractor.Row row1 = newRow(newHit(dataDoc), dataValues, 1); + DataFrameDataExtractor.Row row2 = newRow(newHit(dataDoc), dataValues, 1); + givenDataFrameBatches(Arrays.asList(row1), Arrays.asList(row2)); + + givenProcessResults(Collections.emptyList()); + + verifyNoMoreInteractions(client); + verify(dataExtractor).cancel(); + verify(dataExtractor, times(2)).next(); + } + private void givenProcessResults(List results) { - results.forEach(dataFrameRowsJoiner::processRowResults); + try (DataFrameRowsJoiner joiner = new DataFrameRowsJoiner(ANALYTICS_ID, client, dataExtractor)) { + 
results.forEach(joiner::processRowResults); + } } - private void givenSingleDataFrameBatch(List batch) throws IOException { - when(dataExtractor.hasNext()).thenReturn(true).thenReturn(true).thenReturn(false); - when(dataExtractor.next()).thenReturn(Optional.of(batch)).thenReturn(Optional.empty()); + private void givenDataFrameBatches(List... batches) throws IOException { + DelegateStubDataExtractor delegateStubDataExtractor = new DelegateStubDataExtractor(Arrays.asList(batches)); + when(dataExtractor.hasNext()).thenAnswer(a -> delegateStubDataExtractor.hasNext()); + when(dataExtractor.next()).thenAnswer(a -> delegateStubDataExtractor.next()); } private static SearchHit newHit(String json) { @@ -123,6 +246,7 @@ private static DataFrameDataExtractor.Row newRow(SearchHit hit, String[] values, when(row.getHit()).thenReturn(hit); when(row.getValues()).thenReturn(values); when(row.getChecksum()).thenReturn(checksum); + when(row.shouldSkip()).thenReturn(values == null); return row; } @@ -135,4 +259,22 @@ private void givenClientHasNoFailures() { when(client.execute(same(BulkAction.INSTANCE), bulkRequestCaptor.capture())).thenReturn(responseFuture); when(client.threadPool()).thenReturn(threadPool); } + + private static class DelegateStubDataExtractor { + + private final List> batches; + private int batchIndex; + + private DelegateStubDataExtractor(List> batches) { + this.batches = batches; + } + + public boolean hasNext() { + return batchIndex < batches.size(); + } + + public Optional> next() { + return Optional.of(batches.get(batchIndex++)); + } + } } From 0bf8ede18492c7a350eed1c29b8984fd95069ea1 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 11 Apr 2019 12:30:33 +0300 Subject: [PATCH 31/67] [FEATURE][ML] Remove DataFrameAnalyticsManagerIT After recent changes in build security and the featureAwareTast are now requiring asm 7.1. At the same time painless depends on asm 5.1. Unfortunately, this means `DataFrameAnalyticsManagerIT` is causing jar hell. 
In this commit I'm removing the test. The test is valuable and the intention is to bring it back once we've sorted out the jar mess. But we need to merge master into the feature branch to move on. --- x-pack/plugin/ml/build.gradle | 1 - .../DataFrameAnalyticsManagerIT.java | 412 ------------------ 2 files changed, 413 deletions(-) delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java diff --git a/x-pack/plugin/ml/build.gradle b/x-pack/plugin/ml/build.gradle index f5fe66501c3dc..6ca1a44c145da 100644 --- a/x-pack/plugin/ml/build.gradle +++ b/x-pack/plugin/ml/build.gradle @@ -40,7 +40,6 @@ dependencies { compileOnly project(path: xpackModule('core'), configuration: 'default') compileOnly "org.elasticsearch.plugin:elasticsearch-scripting-painless-spi:${versions.elasticsearch}" testCompile project(path: xpackModule('core'), configuration: 'testArtifacts') - testCompile project(':modules:lang-painless') // This should not be here testCompile project(path: xpackModule('security'), configuration: 'testArtifacts') diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java deleted file mode 100644 index 5c782947a310c..0000000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/DataFrameAnalyticsManagerIT.java +++ /dev/null @@ -1,412 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.ml.integration; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.bulk.BulkRequestBuilder; -import org.elasticsearch.action.bulk.BulkResponse; -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.analysis.common.CommonAnalysisPlugin; -import org.elasticsearch.client.node.NodeClient; -import org.elasticsearch.env.TestEnvironment; -import org.elasticsearch.index.reindex.ReindexPlugin; -import org.elasticsearch.painless.PainlessPlugin; -import org.elasticsearch.persistent.PersistentTasksCustomMetaData; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.transport.Netty4Plugin; -import org.elasticsearch.xpack.core.XPackClientPlugin; -import org.elasticsearch.xpack.core.ml.MlTasks; -import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; -import org.elasticsearch.xpack.ml.LocalStateMachineLearning; -import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask; -import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsFields; -import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; -import 
org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcess; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessConfig; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessFactory; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsProcessManager; -import org.elasticsearch.xpack.ml.dataframe.process.AnalyticsResult; -import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; -import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; -import org.junit.Assert; -import org.junit.Before; - -import java.io.IOException; -import java.time.ZonedDateTime; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.TimeUnit; - -import static org.hamcrest.Matchers.equalTo; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyLong; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class DataFrameAnalyticsManagerIT extends BaseMlIntegTestCase { - - private volatile boolean finished; - private DataFrameAnalyticsConfigProvider provider; - private static double EXPECTED_OUTLIER_SCORE = 42.0; - @Before - public void fieldSetup() { - provider = new DataFrameAnalyticsConfigProvider(client()); - finished = false; - } - - @Override - protected Collection> nodePlugins() { - return Arrays.asList(LocalStateMachineLearning.class, CommonAnalysisPlugin.class, - ReindexPlugin.class, 
PainlessPlugin.class); - } - - @Override - protected Collection> transportClientPlugins() { - return Arrays.asList(XPackClientPlugin.class, Netty4Plugin.class, PainlessPlugin.class, ReindexPlugin.class); - } - - public void testTaskContinuationFromReindexState() throws Exception { - internalCluster().ensureAtLeastNumDataNodes(1); - ensureStableCluster(1); - String sourceIndex = "test-outlier-detection-from-reindex-state"; - createIndexForAnalysis(sourceIndex); - String id = "test_outlier_detection_from_reindex_state"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); - putDataFrameAnalyticsConfig(config); - List results = buildExpectedResults(sourceIndex); - - DataFrameAnalyticsManager manager = createManager(results); - - DataFrameAnalyticsTask task = buildMockedTask(config.getId()); - manager.execute(task, DataFrameAnalyticsState.REINDEXING); - - // wait for markAsCompleted() or markAsFailed() to be called; - assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); - - // Check we've got all docs - SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); - Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); - for(SearchHit hit : searchResponse.getHits().getHits()) { - Map src = hit.getSourceAsMap(); - assertNotNull(src.get("outlier_score")); - assertThat(src.get("outlier_score"), equalTo(EXPECTED_OUTLIER_SCORE)); - } - - verify(task, never()).markAsFailed(any(Exception.class)); - verify(task, times(1)).markAsCompleted(); - } - - public void testTaskContinuationFromReindexStateWithPreviousResultsIndex() throws Exception { - internalCluster().ensureAtLeastNumDataNodes(1); - ensureStableCluster(1); - String sourceIndex = "test-outlier-detection-from-reindex-state-with-results"; - createIndexForAnalysis(sourceIndex); - String id = "test_outlier_detection_from_reindex_state_with_results"; - DataFrameAnalyticsConfig config = 
buildOutlierDetectionAnalytics(id, sourceIndex); - putDataFrameAnalyticsConfig(config); - - // Create the "results" index, as if we ran reindex already in the process, but did not transition from the state properly - createAnalysesResultsIndex(config.getDest().getIndex(), false); - List results = buildExpectedResults(sourceIndex); - - DataFrameAnalyticsManager manager = createManager(results); - - DataFrameAnalyticsTask task = buildMockedTask(config.getId()); - manager.execute(task, DataFrameAnalyticsState.REINDEXING); - - // wait for markAsCompleted() or markAsFailed() to be called; - assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); - - // Check we've got all docs - SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); - Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); - for(SearchHit hit : searchResponse.getHits().getHits()) { - Map src = hit.getSourceAsMap(); - assertNotNull(src.get("outlier_score")); - assertThat(src.get("outlier_score"), equalTo(EXPECTED_OUTLIER_SCORE)); - } - - verify(task, never()).markAsFailed(any(Exception.class)); - verify(task, times(1)).markAsCompleted(); - } - - public void testTaskContinuationFromAnalyzeState() throws Exception { - internalCluster().ensureAtLeastNumDataNodes(1); - ensureStableCluster(1); - String sourceIndex = "test-outlier-detection-from-analyze-state"; - createIndexForAnalysis(sourceIndex); - String id = "test_outlier_detection_from_analyze_state"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); - putDataFrameAnalyticsConfig(config); - // Create the "results" index to simulate running reindex already and having partially ran analysis - createAnalysesResultsIndex(config.getDest().getIndex(), true); - List results = buildExpectedResults(sourceIndex); - - DataFrameAnalyticsManager manager = createManager(results); - - DataFrameAnalyticsTask task = 
buildMockedTask(config.getId()); - manager.execute(task, DataFrameAnalyticsState.ANALYZING); - - // wait for markAsCompleted() or markAsFailed() to be called; - assertBusy(() -> assertTrue(isCompleted()), 120, TimeUnit.SECONDS); - - // Check we've got all docs - SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); - Assert.assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) 5)); - for(SearchHit hit : searchResponse.getHits().getHits()) { - Map src = hit.getSourceAsMap(); - assertNotNull(src.get("outlier_score")); - assertThat(src.get("outlier_score"), equalTo(EXPECTED_OUTLIER_SCORE)); - } - - verify(task, never()).markAsFailed(any(Exception.class)); - verify(task, times(1)).markAsCompleted(); - // Need to verify that we did not reindex again, as we already had the full destination index - verify(task, never()).setReindexingTaskId(anyLong()); - } - - private synchronized void completed() { - finished = true; - } - - private synchronized boolean isCompleted() { - return finished; - } - - private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex) { - DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); - configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); - configBuilder.setDest(new DataFrameAnalyticsDest(sourceIndex + "-results")); - Map analysisConfig = new HashMap<>(); - analysisConfig.put("outlier_detection", Collections.emptyMap()); - configBuilder.setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig(analysisConfig))); - return configBuilder.build(); - } - - @SuppressWarnings("unchecked") - private void putDataFrameAnalyticsConfig(DataFrameAnalyticsConfig config) throws Exception { - PlainActionFuture future = new PlainActionFuture(); - provider.put(config, Collections.emptyMap(), future); - future.get(); - } - - private void createIndexForAnalysis(String indexName) { 
- client().admin().indices().prepareCreate(indexName) - .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical_1", "type=keyword") - .get(); - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); - bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - for (int i = 0; i < 5; i++) { - IndexRequest indexRequest = new IndexRequest(indexName); - - // We insert one odd value out of 5 for one feature - String docId = i == 0 ? "outlier" : "normal" + i; - indexRequest.id(docId); - indexRequest.source("numeric_1", i == 0 ? 100.0 : 1.0, "numeric_2", 1.0, "categorical_1", "foo_" + i); - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - Assert.fail("Failed to index data: " + bulkResponse.buildFailureMessage()); - } - } - - private void createAnalysesResultsIndex(String indexName, boolean includeOutlierScore) { - client().admin().indices().prepareCreate(indexName) - .addMapping("_doc", - "numeric_1", "type=double", - "numeric_2", "type=float", - "categorical_1", "type=keyword", - DataFrameAnalyticsFields.ID, "type=keyword") - .get(); - - BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); - bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - - for (int i = 0; i < 5; i++) { - IndexRequest indexRequest = new IndexRequest(indexName); - - // We insert one odd value out of 5 for one feature - String docId = i == 0 ? "outlier" : "normal" + i; - indexRequest.id(docId); - if (includeOutlierScore) { - indexRequest.source("numeric_1", i == 0 ? 100.0 : 1.0, - "numeric_2", 1.0, - "categorical_1", "foo_" + i, - DataFrameAnalyticsFields.ID, docId, - "outlier_score", 10.0); // simply needs to be a score different than expected - } else { - indexRequest.source("numeric_1", i == 0 ? 
100.0 : 1.0, - "numeric_2", 1.0, - "categorical_1", "foo_" + i, - DataFrameAnalyticsFields.ID, docId); - } - bulkRequestBuilder.add(indexRequest); - } - BulkResponse bulkResponse = bulkRequestBuilder.get(); - if (bulkResponse.hasFailures()) { - Assert.fail("Failed to index data: " + bulkResponse.buildFailureMessage()); - } - } - - private List buildExpectedResults(String index) throws Exception { - SearchHit[] hits = client().search(new SearchRequest(index)).get().getHits().getHits(); - Arrays.sort(hits, Comparator.comparing(SearchHit::getId)); - List results = new ArrayList<>(hits.length); - for (SearchHit hit : hits) { - String[] fields = new String[2]; - Map src = hit.getSourceAsMap(); - fields[0] = src.get("numeric_1").toString(); - fields[1] = src.get("numeric_2").toString(); - results.add(new AnalyticsResult(new RowResults(Arrays.hashCode(fields), - Collections.singletonMap("outlier_score", EXPECTED_OUTLIER_SCORE)),null)); - } - return results; - } - - private DataFrameAnalyticsManager createManager(List expectedResults) { - AnalyticsProcessFactory factory = new MockedAnalyticsFactory(expectedResults); - AnalyticsProcessManager processManager = new AnalyticsProcessManager(client(), - TestEnvironment.newEnvironment(internalCluster().getDefaultSettings()), - client().threadPool(), - factory); - return new DataFrameAnalyticsManager(clusterService(), (NodeClient)internalCluster().dataNodeClient(), provider, processManager); - } - - @SuppressWarnings("unchecked") - private DataFrameAnalyticsTask buildMockedTask(String id) { - StartDataFrameAnalyticsAction.TaskParams params = new StartDataFrameAnalyticsAction.TaskParams(id); - DataFrameAnalyticsTask task = mock(DataFrameAnalyticsTask.class); - when(task.getParams()).thenReturn(params); - when(task.getAllocationId()).thenReturn(1L); - doAnswer(invoked -> { - client().threadPool().executor("listener").execute(() -> { - ActionListener listener = (ActionListener) invoked.getArguments()[1]; - final 
PersistentTasksCustomMetaData.PersistentTask resp = new PersistentTasksCustomMetaData.PersistentTask<>(id, - MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, - params, - 1, - new PersistentTasksCustomMetaData.Assignment(null, "none")); - listener.onResponse(resp); - }); - return null; - }).when(task).updatePersistentTaskState(any(DataFrameAnalyticsTaskState.class), any(ActionListener.class)); - doNothing().when(task).setReindexingTaskId(anyLong()); - doAnswer(invoked -> { - completed(); - return null; - }).when(task).markAsCompleted(); - doAnswer(invoked -> { - completed(); - Exception e = (Exception)invoked.getArguments()[0]; - fail(e.getMessage()); - return null; - }).when(task).markAsFailed(any(Exception.class)); - return task; - } - - class MockedAnalyticsFactory implements AnalyticsProcessFactory { - final List results; - - MockedAnalyticsFactory(List resultsToSupply) { - this.results = resultsToSupply; - } - @Override - public AnalyticsProcess createAnalyticsProcess(String jobId, - AnalyticsProcessConfig analyticsProcessConfig, - ExecutorService executorService) { - return new MockedAnalyticsProcess(results); - } - } - - class MockedAnalyticsProcess implements AnalyticsProcess { - - final List results; - final ZonedDateTime start; - MockedAnalyticsProcess(List resultsToSupply) { - results = resultsToSupply; - start = ZonedDateTime.now(); - } - - @Override - public void writeEndOfDataMessage() throws IOException { } - - @Override - public Iterator readAnalyticsResults() { - return results.iterator(); - } - - @Override - public void consumeAndCloseOutputStream() { } - - @Override - public boolean isReady() { - return true; - } - - @Override - public void writeRecord(String[] record) throws IOException { } - - @Override - public void persistState() throws IOException { } - - @Override - public void flushStream() throws IOException { } - - @Override - public void kill() throws IOException { } - - @Override - public ZonedDateTime getProcessStartTime() { - return start; - } 
- - @Override - public boolean isProcessAlive() { - return true; - } - - @Override - public boolean isProcessAliveAfterWaiting() { - return false; - } - - @Override - public String readError() { - return null; - } - - @Override - public void close() throws IOException { } - } -} From d230a07654c1e2471ef62d178699244617cb2ff6 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 11 Apr 2019 16:52:01 +0300 Subject: [PATCH 32/67] [FEATURE][ML] Add setting for results field name (#40927) Adds a field in the config `dest.results_field` which defaults to `ml`. Results will be written into an object that is named after the results field. In addition, upon starting the analytics there is now validation that the source index does not already have a field matching the results field name. This allows composability of different analytics runs. --- .../ml/dataframe/DataFrameAnalyticsDest.java | 23 ++- .../persistence/ElasticsearchMappings.java | 3 + .../ml/job/results/ReservedFieldNames.java | 1 + .../DataFrameAnalyticsDestTests.java | 4 +- .../integration/RunDataFrameAnalyticsIT.java | 17 ++- .../dataframe/DataFrameAnalyticsManager.java | 8 +- .../dataframe/analyses/DataFrameAnalysis.java | 8 - .../dataframe/analyses/OutlierDetection.java | 15 -- .../DataFrameDataExtractorFactory.java | 127 ++-------------- .../extractor/ExtractedFieldsDetector.java | 143 ++++++++++++++++++ .../process/AnalyticsProcessConfig.java | 8 +- .../process/AnalyticsProcessManager.java | 2 +- ...java => ExtractedFieldsDetectorTests.java} | 128 +++++++++++----- 13 files changed, 289 insertions(+), 198 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java rename x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/{DataFrameDataExtractorFactoryTests.java => ExtractedFieldsDetectorTests.java} (59%) diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java index 3f3c2636ed3c2..98f1bdeb19189 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -20,40 +21,50 @@ public class DataFrameAnalyticsDest implements Writeable, ToXContentObject { public static final ParseField INDEX = new ParseField("index"); + public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + + private static final String DEFAULT_RESULTS_FIELD = "ml"; public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { ConstructingObjectParser parser = new ConstructingObjectParser<>("data_frame_analytics_dest", - ignoreUnknownFields, a -> new DataFrameAnalyticsDest((String) a[0])); + ignoreUnknownFields, a -> new DataFrameAnalyticsDest((String) a[0], (String) a[1])); parser.declareString(ConstructingObjectParser.constructorArg(), INDEX); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD); return parser; } private final String index; + private final String resultsField; - public DataFrameAnalyticsDest(String index) { + public DataFrameAnalyticsDest(String index, @Nullable String resultsField) { this.index = ExceptionsHelper.requireNonNull(index, INDEX); if (index.isEmpty()) { throw ExceptionsHelper.badRequestException("[{}] must be non-empty", INDEX); } + this.resultsField = resultsField == null ? 
DEFAULT_RESULTS_FIELD : resultsField; } public DataFrameAnalyticsDest(StreamInput in) throws IOException { index = in.readString(); + resultsField = in.readString(); } public DataFrameAnalyticsDest(DataFrameAnalyticsDest other) { this.index = other.index; + this.resultsField = other.resultsField; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(index); + out.writeString(resultsField); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(INDEX.getPreferredName(), index); + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); builder.endObject(); return builder; } @@ -64,15 +75,19 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; DataFrameAnalyticsDest other = (DataFrameAnalyticsDest) o; - return Objects.equals(index, other.index); + return Objects.equals(index, other.index) && Objects.equals(resultsField, other.resultsField); } @Override public int hashCode() { - return Objects.hash(index); + return Objects.hash(index, resultsField); } public String getIndex() { return index; } + + public String getResultsField() { + return resultsField; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 75a94e8ef4d7d..c33b2f2943b6d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -408,6 +408,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(DataFrameAnalyticsDest.INDEX.getPreferredName()) .field(TYPE, KEYWORD) .endObject() + 
.startObject(DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() .endObject() .endObject() .startObject(DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName()) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 6786d5eade59f..7530c120b0fd2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -265,6 +265,7 @@ public final class ReservedFieldNames { DataFrameAnalyticsConfig.ANALYSES.getPreferredName(), DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName(), DataFrameAnalyticsDest.INDEX.getPreferredName(), + DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), DataFrameAnalyticsSource.INDEX.getPreferredName(), DataFrameAnalyticsSource.QUERY.getPreferredName(), "outlier_detection", diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java index 0e34be10a21ff..7332687723805 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -24,7 +24,9 @@ protected DataFrameAnalyticsDest createTestInstance() { } public static DataFrameAnalyticsDest createRandom() { - return new DataFrameAnalyticsDest(randomAlphaOfLength(10)); + String index = randomAlphaOfLength(10); + String resultsField = randomBoolean() ? 
null : randomAlphaOfLength(10); + return new DataFrameAnalyticsDest(index, resultsField); } @Override diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 7abb831462f17..cce6b765d388a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.Nullable; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; @@ -65,7 +66,7 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { } String id = "test_outlier_detection_with_few_docs"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex, null); registerAnalytics(config); putAnalytics(config); @@ -86,8 +87,10 @@ public void testOutlierDetectionWithFewDocuments() throws Exception { assertThat(destDoc.containsKey(field), is(true)); assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); } - assertThat(destDoc.containsKey("outlier_score"), is(true)); - double outlierScore = (double) destDoc.get("outlier_score"); + assertThat(destDoc.containsKey("ml"), is(true)); + Map resultsObject = (Map) destDoc.get("ml"); + assertThat(resultsObject.containsKey("outlier_score"), is(true)); + double outlierScore = (double) 
resultsObject.get("outlier_score"); assertThat(outlierScore, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(100.0))); if (hit.getId().equals("outlier")) { scoreOfOutlier = outlierScore; @@ -124,7 +127,7 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { } String id = "test_outlier_detection_with_enough_docs_to_scroll"; - DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex); + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex, "custom_ml"); registerAnalytics(config); putAnalytics(config); @@ -140,14 +143,14 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { // Check they all have an outlier_score searchResponse = client().prepareSearch(config.getDest().getIndex()) .setTrackTotalHits(true) - .setQuery(QueryBuilders.existsQuery("outlier_score")).get(); + .setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get(); assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); } - private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex) { + private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex, @Nullable String resultsField) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); - configBuilder.setDest(new DataFrameAnalyticsDest(sourceIndex + "-results")); + configBuilder.setDest(new DataFrameAnalyticsDest(sourceIndex + "-results", resultsField)); Map analysisConfig = new HashMap<>(); analysisConfig.put("outlier_detection", Collections.emptyMap()); configBuilder.setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig(analysisConfig))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 4319bcada0bbd..02369adfd785d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -101,7 +101,7 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current // The task has fully reindexed the documents and we should continue on with our analyses case ANALYZING: // TODO apply previously stored model state if applicable - startAnalytics(task, config); + startAnalytics(task, config, true); break; // If we are already at REINDEXING, we are not 100% sure if we reindexed ALL the docs. // We will delete the destination index, recreate, reindex @@ -139,7 +139,7 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { // Reindexing is complete; start analytics ActionListener refreshListener = ActionListener.wrap( - refreshResponse -> startAnalytics(task, config), + refreshResponse -> startAnalytics(task, config, false), task::markAsFailed ); @@ -177,7 +177,7 @@ private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataF createDestinationIndex(config.getSource().getIndex(), config.getDest().getIndex(), config.getHeaders(), copyIndexCreatedListener); } - private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { + private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, boolean isTaskRestarting) { // Update state to ANALYZING and start process ActionListener dataExtractorFactoryListener = ActionListener.wrap( dataExtractorFactory -> { @@ -201,7 +201,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi // TODO This could fail with 
errors. In that case we get stuck with the copied index. // We could delete the index in case of failure or we could try building the factory before reindexing // to catch the error early on. - DataFrameDataExtractorFactory.create(client, config, dataExtractorFactoryListener); + DataFrameDataExtractorFactory.create(client, config, isTaskRestarting, dataExtractorFactoryListener); } private void createDestinationIndex(String sourceIndex, String destinationIndex, Map headers, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java index 03139ac8c9edc..9fdd093fa324e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java @@ -10,7 +10,6 @@ import java.util.Locale; import java.util.Map; -import java.util.Set; public interface DataFrameAnalysis extends ToXContentObject { @@ -33,13 +32,6 @@ public String toString() { Type getType(); - /** - * The fields that will contain the results of the analysis - * - * @return Set of Strings representing the result fields for the constructed analysis - */ - Set getResultFields(); - interface Factory { /** diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java index 3b93f373546a2..47f614ba658f6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java @@ -5,12 +5,9 @@ */ package org.elasticsearch.xpack.ml.dataframe.analyses; -import java.util.Collections; import java.util.HashMap; -import 
java.util.LinkedHashSet; import java.util.Locale; import java.util.Map; -import java.util.Set; public class OutlierDetection extends AbstractDataFrameAnalysis { @@ -30,13 +27,6 @@ public String toString() { public static final String NUMBER_NEIGHBOURS = "number_neighbours"; public static final String METHOD = "method"; - private static final Set RESULT_FIELDS; - static { - Set set = new LinkedHashSet<>(); - set.add("outlier_score"); - RESULT_FIELDS = Collections.unmodifiableSet(set); - } - private final Integer numberNeighbours; private final Method method; @@ -62,11 +52,6 @@ protected Map getParams() { return params; } - @Override - public Set getResultFields() { - return RESULT_FIELDS; - } - static class Factory implements DataFrameAnalysis.Factory { @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index c827e0b48edd5..f7fc0faf0b011 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -7,60 +7,22 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.fieldcaps.FieldCapabilities; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.client.Client; -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.regex.Regex; import org.elasticsearch.index.IndexNotFoundException; -import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.index.query.QueryBuilders; -import 
org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.NameResolver; -import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; -import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; -import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; public class DataFrameDataExtractorFactory { - /** - * Fields to ignore. These are mostly internal meta fields. 
- */ - private static final List IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", - "_source", "_type", "_uid", "_version", "_feature", "_ignored"); - - /** - * The types supported by data frames - */ - private static final Set COMPATIBLE_FIELD_TYPES; - - static { - Set compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values()) - .map(NumberFieldMapper.NumberType::typeName) - .collect(Collectors.toSet()); - compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module - - COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes); - } - private final Client client; private final String analyticsId; private final String index; @@ -96,15 +58,15 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) { * * @param client ES Client used to make calls against the cluster * @param config The config from which to create the extractor factory + * @param isTaskRestarting Whether the task is restarting * @param listener The listener to notify on creation or failure */ public static void create(Client client, DataFrameAnalyticsConfig config, + boolean isTaskRestarting, ActionListener listener) { - Set resultFields = resolveResultsFields(config); - validateIndexAndExtractFields(client, config.getHeaders(), config.getDest().getIndex(), config.getAnalysesFields(), resultFields, - ActionListener.wrap( - extractedFields -> listener.onResponse(new DataFrameDataExtractorFactory( + validateIndexAndExtractFields(client, config.getDest().getIndex(), config, isTaskRestarting, + ActionListener.wrap(extractedFields -> listener.onResponse(new DataFrameDataExtractorFactory( client, config.getId(), config.getDest().getIndex(), extractedFields, config.getHeaders())), listener::onFailure )); @@ -120,9 +82,7 @@ public static void create(Client client, public static void validateConfigAndSourceIndex(Client client, DataFrameAnalyticsConfig config, ActionListener listener) { - Set 
resultFields = resolveResultsFields(config); - validateIndexAndExtractFields(client, config.getHeaders(), config.getSource().getIndex(), config.getAnalysesFields(), resultFields, - ActionListener.wrap( + validateIndexAndExtractFields(client, config.getSource().getIndex(), config, false, ActionListener.wrap( fields -> { config.getSource().getParsedQuery(); // validate query is acceptable listener.onResponse(config); @@ -131,80 +91,15 @@ public static void validateConfigAndSourceIndex(Client client, )); } - // Visible for testing - static ExtractedFields detectExtractedFields(String index, - FetchSourceContext desiredFields, - Set resultFields, - FieldCapabilitiesResponse fieldCapabilitiesResponse) { - Set fields = fieldCapabilitiesResponse.get().keySet(); - fields.removeAll(IGNORE_FIELDS); - // TODO a better solution may be to have some sort of known prefix and filtering that - fields.removeAll(resultFields); - removeFieldsWithIncompatibleTypes(fields, fieldCapabilitiesResponse); - includeAndExcludeFields(fields, desiredFields, index); - List sortedFields = new ArrayList<>(fields); - // We sort the fields to ensure the checksum for each document is deterministic - Collections.sort(sortedFields); - ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse) - .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE); - if (extractedFields.getAllFields().isEmpty()) { - throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index [{}]", index); - } - return extractedFields; - } - - private static void removeFieldsWithIncompatibleTypes(Set fields, FieldCapabilitiesResponse fieldCapabilitiesResponse) { - Iterator fieldsIterator = fields.iterator(); - while (fieldsIterator.hasNext()) { - String field = fieldsIterator.next(); - Map fieldCaps = fieldCapabilitiesResponse.getField(field); - if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) { - 
fieldsIterator.remove(); - } - } - } - - private static void includeAndExcludeFields(Set fields, FetchSourceContext desiredFields, String index) { - if (desiredFields == null) { - return; - } - String includes = desiredFields.includes().length == 0 ? "*" : Strings.arrayToCommaDelimitedString(desiredFields.includes()); - String excludes = Strings.arrayToCommaDelimitedString(desiredFields.excludes()); - - if (Regex.isMatchAllPattern(includes) && excludes.isEmpty()) { - return; - } - try { - // If the inclusion set does not match anything, that means the user's desired fields cannot be found in - // the collection of supported field types. We should let the user know. - Set includedSet = NameResolver.newUnaliased(fields, - (ex) -> new ResourceNotFoundException(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER, index, ex))) - .expand(includes, false); - // If the exclusion set does not match anything, that means the fields are already not present - // no need to raise if nothing matched - Set excludedSet = NameResolver.newUnaliased(fields, - (ex) -> new ResourceNotFoundException(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER, index, ex))) - .expand(excludes, true); - - fields.retainAll(includedSet); - fields.removeAll(excludedSet); - } catch (ResourceNotFoundException ex) { - // Re-wrap our exception so that we throw the same exception type when there are no fields. - throw ExceptionsHelper.badRequestException(ex.getMessage()); - } - - } - private static void validateIndexAndExtractFields(Client client, - Map headers, String index, - FetchSourceContext desiredFields, - Set resultFields, + DataFrameAnalyticsConfig config, + boolean isTaskRestarting, ActionListener listener) { // Step 2. 
Extract fields (if possible) and notify listener ActionListener fieldCapabilitiesHandler = ActionListener.wrap( fieldCapabilitiesResponse -> listener.onResponse( - detectExtractedFields(index, desiredFields, resultFields, fieldCapabilitiesResponse)), + new ExtractedFieldsDetector(index, config, isTaskRestarting, fieldCapabilitiesResponse).detect()), e -> { if (e instanceof IndexNotFoundException) { listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " @@ -219,15 +114,11 @@ private static void validateIndexAndExtractFields(Client client, FieldCapabilitiesRequest fieldCapabilitiesRequest = new FieldCapabilitiesRequest(); fieldCapabilitiesRequest.indices(index); fieldCapabilitiesRequest.fields("*"); - ClientHelper.executeWithHeaders(headers, ClientHelper.ML_ORIGIN, client, () -> { + ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, () -> { client.execute(FieldCapabilitiesAction.INSTANCE, fieldCapabilitiesRequest, fieldCapabilitiesHandler); // This response gets discarded - the listener handles the real response return null; }); } - private static Set resolveResultsFields(DataFrameAnalyticsConfig config) { - List analyses = DataFrameAnalysesUtils.readAnalyses(config.getAnalyses()); - return analyses.stream().flatMap(analysis -> analysis.getResultFields().stream()).collect(Collectors.toSet()); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java new file mode 100644 index 0000000000000..1c363b62bfaef --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.extractor; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.fieldcaps.FieldCapabilities; +import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.NameResolver; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class ExtractedFieldsDetector { + + /** + * Fields to ignore. These are mostly internal meta fields. 
+ */ + private static final List IGNORE_FIELDS = Arrays.asList("_id", "_field_names", "_index", "_parent", "_routing", "_seq_no", + "_source", "_type", "_uid", "_version", "_feature", "_ignored"); + + /** + * The types supported by data frames + */ + private static final Set COMPATIBLE_FIELD_TYPES; + + static { + Set compatibleTypes = Stream.of(NumberFieldMapper.NumberType.values()) + .map(NumberFieldMapper.NumberType::typeName) + .collect(Collectors.toSet()); + compatibleTypes.add("scaled_float"); // have to add manually since scaled_float is in a module + + COMPATIBLE_FIELD_TYPES = Collections.unmodifiableSet(compatibleTypes); + } + + private final String index; + private final DataFrameAnalyticsConfig config; + private final boolean isTaskRestarting; + private final FieldCapabilitiesResponse fieldCapabilitiesResponse; + + ExtractedFieldsDetector(String index, DataFrameAnalyticsConfig config, boolean isTaskRestarting, + FieldCapabilitiesResponse fieldCapabilitiesResponse) { + this.index = Objects.requireNonNull(index); + this.config = Objects.requireNonNull(config); + this.isTaskRestarting = isTaskRestarting; + this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse); + } + + public ExtractedFields detect() { + Set fields = fieldCapabilitiesResponse.get().keySet(); + fields.removeAll(IGNORE_FIELDS); + + checkResultsFieldIsNotPresent(fields, index); + + // Ignore fields under the results object + fields.removeIf(field -> field.startsWith(config.getDest().getResultsField() + ".")); + + removeFieldsWithIncompatibleTypes(fields); + includeAndExcludeFields(fields, index); + List sortedFields = new ArrayList<>(fields); + // We sort the fields to ensure the checksum for each document is deterministic + Collections.sort(sortedFields); + ExtractedFields extractedFields = ExtractedFields.build(sortedFields, Collections.emptySet(), fieldCapabilitiesResponse) + .filterFields(ExtractedField.ExtractionMethod.DOC_VALUE); + if 
(extractedFields.getAllFields().isEmpty()) { + throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index [{}]", index); + } + return extractedFields; + } + + private void checkResultsFieldIsNotPresent(Set fields, String index) { + // If the task is restarting we do not mind the index containing the results field, we will overwrite all docs + if (isTaskRestarting == false && fields.contains(config.getDest().getResultsField())) { + throw ExceptionsHelper.badRequestException("Index [{}] already has a field that matches the {}.{} [{}];" + + " please set a different {}", index, DataFrameAnalyticsConfig.DEST.getPreferredName(), + DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), config.getDest().getResultsField(), + DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName()); + } + } + + private void removeFieldsWithIncompatibleTypes(Set fields) { + Iterator fieldsIterator = fields.iterator(); + while (fieldsIterator.hasNext()) { + String field = fieldsIterator.next(); + Map fieldCaps = fieldCapabilitiesResponse.getField(field); + if (fieldCaps == null || COMPATIBLE_FIELD_TYPES.containsAll(fieldCaps.keySet()) == false) { + fieldsIterator.remove(); + } + } + } + + private void includeAndExcludeFields(Set fields, String index) { + FetchSourceContext analysesFields = config.getAnalysesFields(); + if (analysesFields == null) { + return; + } + String includes = analysesFields.includes().length == 0 ? "*" : Strings.arrayToCommaDelimitedString(analysesFields.includes()); + String excludes = Strings.arrayToCommaDelimitedString(analysesFields.excludes()); + + if (Regex.isMatchAllPattern(includes) && excludes.isEmpty()) { + return; + } + try { + // If the inclusion set does not match anything, that means the user's desired fields cannot be found in + // the collection of supported field types. We should let the user know. 
+ Set includedSet = NameResolver.newUnaliased(fields, + (ex) -> new ResourceNotFoundException(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER, index, ex))) + .expand(includes, false); + // If the exclusion set does not match anything, that means the fields are already not present + // no need to raise if nothing matched + Set excludedSet = NameResolver.newUnaliased(fields, + (ex) -> new ResourceNotFoundException(Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER, index, ex))) + .expand(excludes, true); + + fields.retainAll(includedSet); + fields.removeAll(excludedSet); + } catch (ResourceNotFoundException ex) { + // Re-wrap our exception so that we throw the same exception type when there are no fields. + throw ExceptionsHelper.badRequestException(ex.getMessage()); + } + } + +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 36507e3da292b..98f91fb601f87 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -20,19 +20,22 @@ public class AnalyticsProcessConfig implements ToXContentObject { private static final String MEMORY_LIMIT = "memory_limit"; private static final String THREADS = "threads"; private static final String ANALYSIS = "analysis"; + private static final String RESULTS_FIELD = "results_field"; private final long rows; private final int cols; private final ByteSizeValue memoryLimit; private final int threads; private final DataFrameAnalysis analysis; + private final String resultsField; - - public AnalyticsProcessConfig(long rows, int cols, ByteSizeValue memoryLimit, int threads, DataFrameAnalysis analysis) { + public AnalyticsProcessConfig(long rows, int cols, 
ByteSizeValue memoryLimit, int threads, String resultsField, + DataFrameAnalysis analysis) { this.rows = rows; this.cols = cols; this.memoryLimit = Objects.requireNonNull(memoryLimit); this.threads = threads; + this.resultsField = Objects.requireNonNull(resultsField); this.analysis = Objects.requireNonNull(analysis); } @@ -47,6 +50,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(COLS, cols); builder.field(MEMORY_LIMIT, memoryLimit.getBytes()); builder.field(THREADS, threads); + builder.field(RESULTS_FIELD, resultsField); builder.field(ANALYSIS, analysis); builder.endObject(); return builder; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index f4c97bdd909f0..6f939c7c18b11 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -150,7 +150,7 @@ private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig conf assert dataFrameAnalyses.size() == 1; AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, - config.getModelMemoryLimit(), 1, dataFrameAnalyses.get(0)); + config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), dataFrameAnalyses.get(0)); return processConfig; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java similarity index 59% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java rename to 
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index affc9cb0b0ca8..905349aa72840 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -10,6 +10,10 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; @@ -19,7 +23,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -27,64 +30,71 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class DataFrameDataExtractorFactoryTests extends ESTestCase { +public class ExtractedFieldsDetectorTests extends ESTestCase { - private static final String INDEX = "source_index"; - private static final FetchSourceContext EMPTY_CONTEXT = new FetchSourceContext(true, new String[0], new String[0]); - private static final Set EMPTY_RESULT_SET = Collections.emptySet(); + private static final String SOURCE_INDEX = "source_index"; + private static final String DEST_INDEX = "dest_index"; + private static final String RESULTS_FIELD = "ml"; - public void testDetectExtractedFields_GivenFloatField() { + public void testDetect_GivenFloatField() 
{ FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float").build(); - ExtractedFields extractedFields = - DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); assertThat(allFields.get(0).getName(), equalTo("some_float")); } - public void testDetectExtractedFields_GivenNumericFieldWithMultipleTypes() { + public void testDetect_GivenNumericFieldWithMultipleTypes() { FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("some_number", "long", "integer", "short", "byte", "double", "float", "half_float", "scaled_float") .build(); - ExtractedFields extractedFields = - DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); assertThat(allFields.get(0).getName(), equalTo("some_number")); } - public void testDetectExtractedFields_GivenNonNumericField() { + public void testDetect_GivenNonNumericField() { FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("some_keyword", "keyword").build(); - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities)); + 
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } - public void testDetectExtractedFields_GivenFieldWithNumericAndNonNumericTypes() { + public void testDetect_GivenFieldWithNumericAndNonNumericTypes() { FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("indecisive_field", "float", "keyword").build(); - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities)); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } - public void testDetectExtractedFields_GivenMultipleFields() { + public void testDetect_GivenMultipleFields() { FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("some_float", "float") .addAggregatableField("some_long", "long") .addAggregatableField("some_keyword", "keyword") .build(); - ExtractedFields extractedFields = - DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = 
extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(2)); @@ -92,16 +102,18 @@ public void testDetectExtractedFields_GivenMultipleFields() { containsInAnyOrder("some_float", "some_long")); } - public void testDetectExtractedFields_GivenIgnoredField() { + public void testDetect_GivenIgnoredField() { FieldCapabilitiesResponse fieldCapabilities= new MockFieldCapsResponseBuilder() .addAggregatableField("_id", "float").build(); - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities)); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } - public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() { + public void testDetect_ShouldSortFieldsAlphabetically() { int fieldCount = randomIntBetween(10, 20); List fields = new ArrayList<>(); for (int i = 0; i < fieldCount; i++) { @@ -116,8 +128,9 @@ public void testDetectExtractedFields_ShouldSortFieldsAlphabetically() { } FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); - ExtractedFields extractedFields = - DataFrameDataExtractorFactory.detectExtractedFields(INDEX, EMPTY_CONTEXT, EMPTY_RESULT_SET, fieldCapabilities); + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); @@ -131,8 +144,11 @@ public void 
testDetectedExtractedFields_GivenIncludeWithMissingField() { .build(); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]); - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, EMPTY_RESULT_SET, fieldCapabilities)); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index] with name [your_field1]")); } @@ -143,8 +159,10 @@ public void testDetectedExtractedFields_GivenExcludeAllValidFields() { .build(); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); - ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, EMPTY_RESULT_SET, fieldCapabilities)); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -157,29 +175,63 @@ public void testDetectedExtractedFields_GivenInclusionsAndExclusions() { .build(); FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); - ExtractedFields extractedFields = - DataFrameDataExtractorFactory.detectExtractedFields(INDEX, desiredFields, EMPTY_RESULT_SET, fieldCapabilities); + + 
ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); } - public void testDetectedExtractedFields_GivenAResultField() { + public void testDetectedExtractedFields_GivenIndexContainsResultsField() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField(RESULTS_FIELD, "float") + .addAggregatableField("my_field1", "float") + .addAggregatableField("your_field2", "float") + .addAggregatableField("your_keyword", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); + + assertThat(e.getMessage(), equalTo("Index [source_index] already has a field that matches the dest.results_field [ml]; " + + "please set a different results_field")); + } + + public void testDetectedExtractedFields_GivenIndexContainsResultsFieldAndTaskIsRestarting() { FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() - .addAggregatableField("outlier_score", "float") + .addAggregatableField(RESULTS_FIELD + ".outlier_score", "float") .addAggregatableField("my_field1", "float") .addAggregatableField("your_field2", "float") .addAggregatableField("your_keyword", "keyword") .build(); - ExtractedFields extractedFields = DataFrameDataExtractorFactory.detectExtractedFields(INDEX, - EMPTY_CONTEXT, - Collections.singleton("outlier_score"), - fieldCapabilities); + + ExtractedFieldsDetector extractedFieldsDetector = 
new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), true, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) .collect(Collectors.toList()); assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); } + private static DataFrameAnalyticsConfig buildAnalyticsConfig() { + return buildAnalyticsConfig(null); + } + + private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analysesFields) { + return new DataFrameAnalyticsConfig.Builder("foo") + .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) + .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) + .setAnalysesFields(analysesFields) + .setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig( + Collections.singletonMap("outlier_detection", Collections.emptyMap())))) + .build(); + } + private static class MockFieldCapsResponseBuilder { private final Map> fieldCaps = new HashMap<>(); From ef71e4507eb285fb0055074834e640db16dad560 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Thu, 11 Apr 2019 21:30:22 +0100 Subject: [PATCH 33/67] [ML][FEATURE] Use memory limit to assign data frame analytics jobs (#40892) This change switches data frame analytics jobs from simple assignment to the first available ML node to assignment based on the memory requirement of the data frame analytics job and the available memory on each ML nodes. The memory requirement of a data frame analytics job is now defined to be its model memory limit plus 20MB for process overhead. Available memory on each ML node takes into account running anomaly detector jobs as well as data frame analytics jobs. To facilitate this the shared logic has been moved from TransportOpenJobAction to a new class, JobNodeSelector. 
--- .../action/StartDataFrameAnalyticsAction.java | 6 +- .../dataframe/DataFrameAnalyticsConfig.java | 1 + .../xpack/ml/MachineLearning.java | 3 +- .../ml/action/TransportOpenJobAction.java | 280 +-------- ...ransportStartDataFrameAnalyticsAction.java | 80 ++- .../xpack/ml/job/JobNodeSelector.java | 328 ++++++++++ .../xpack/ml/process/MlMemoryTracker.java | 60 +- .../action/TransportOpenJobActionTests.java | 394 +----------- .../xpack/ml/job/JobNodeSelectorTests.java | 575 ++++++++++++++++++ 9 files changed, 1053 insertions(+), 674 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java index aec69e3f4e9a3..e3a699ac27497 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java @@ -158,6 +158,9 @@ static class RequestBuilder extends ActionRequestBuilder PARSER = new ConstructingObjectParser<>( MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0])); @@ -186,8 +189,7 @@ public String getWriteableName() { @Override public Version getMinimalSupportedVersion() { - // TODO Update to first released version - return Version.CURRENT; + return VERSION_INTRODUCED; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index 54f0fb646ba81..f1fe8ced110a5 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -37,6 +37,7 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { public static final ByteSizeValue DEFAULT_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.GB); public static final ByteSizeValue MIN_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.MB); + public static final ByteSizeValue PROCESS_MEMORY_OVERHEAD = new ByteSizeValue(20, ByteSizeUnit.MB); public static final ParseField ID = new ParseField("id"); public static final ParseField SOURCE = new ParseField("source"); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 5b8f4a103ffa5..98417cea897b7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -541,7 +541,8 @@ public List> getPersistentTasksExecutor(ClusterServic new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get(), memoryTracker.get(), client), new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(datafeedManager.get()), - new TransportStartDataFrameAnalyticsAction.TaskExecutor(dataFrameAnalyticsManager.get()) + new TransportStartDataFrameAnalyticsAction.TaskExecutor(settings, clusterService, dataFrameAnalyticsManager.get(), + memoryTracker.get()) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java index 99df655114ca2..c134d4f96b731 100644 --- 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceAlreadyExistsException; -import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.IndicesOptions; @@ -24,7 +23,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; @@ -45,7 +43,6 @@ import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; -import org.elasticsearch.xpack.core.ml.job.config.DetectionRule; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; @@ -54,16 +51,15 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.MlConfigMigrationEligibilityCheck; +import org.elasticsearch.xpack.ml.job.JobNodeSelector; import org.elasticsearch.xpack.ml.job.persistence.JobConfigProvider; import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import java.util.ArrayList; -import java.util.Collection; -import java.util.LinkedList; import java.util.List; import java.util.Map; -import 
java.util.Set; +import java.util.Objects; import java.util.function.Predicate; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -82,9 +78,6 @@ To ensure that a subsequent close job call will see that same task status (and s */ public class TransportOpenJobAction extends TransportMasterNodeAction { - private static final PersistentTasksCustomMetaData.Assignment AWAITING_LAZY_ASSIGNMENT = - new PersistentTasksCustomMetaData.Assignment(null, "persistent task is awaiting node assignment."); - static final PersistentTasksCustomMetaData.Assignment AWAITING_MIGRATION = new PersistentTasksCustomMetaData.Assignment(null, "job cannot be assigned until it has been migrated."); @@ -130,218 +123,6 @@ static void validate(String jobId, Job job) { } } - static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String jobId, Job job, - ClusterState clusterState, - int dynamicMaxOpenJobs, - int maxConcurrentJobAllocations, - int maxMachineMemoryPercent, - MlMemoryTracker memoryTracker, - boolean isMemoryTrackerRecentlyRefreshed, - Logger logger) { - // TODO: remove in 8.0.0 - boolean allNodesHaveDynamicMaxWorkers = clusterState.getNodes().getMinNodeVersion().onOrAfter(Version.V_7_1_0); - - // Try to allocate jobs according to memory usage, but if that's not possible (maybe due to a mixed version cluster or maybe - // because of some weird OS problem) then fall back to the old mechanism of only considering numbers of assigned jobs - boolean allocateByMemory = isMemoryTrackerRecentlyRefreshed; - if (isMemoryTrackerRecentlyRefreshed == false) { - logger.warn("Falling back to allocating job [{}] by job counts because a memory requirement refresh could not be scheduled", - jobId); - } - - List reasons = new LinkedList<>(); - long maxAvailableCount = Long.MIN_VALUE; - long maxAvailableMemory = Long.MIN_VALUE; - DiscoveryNode minLoadedNodeByCount = null; - DiscoveryNode minLoadedNodeByMemory = null; - PersistentTasksCustomMetaData persistentTasks = 
clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); - for (DiscoveryNode node : clusterState.getNodes()) { - if (MachineLearning.isMlNode(node) == false) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameOrId(node) - + "], because this node isn't a ml node."; - logger.trace(reason); - reasons.add(reason); - continue; - } - - if (nodeSupportsModelSnapshotVersion(node, job) == false) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndVersion(node) - + "], because the job's model snapshot requires a node of version [" - + job.getModelSnapshotMinVersion() + "] or higher"; - logger.trace(reason); - reasons.add(reason); - continue; - } - - Set compatibleJobTypes = Job.getCompatibleJobTypes(node.getVersion()); - if (compatibleJobTypes.contains(job.getJobType()) == false) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndVersion(node) + - "], because this node does not support jobs of type [" + job.getJobType() + "]"; - logger.trace(reason); - reasons.add(reason); - continue; - } - - if (jobHasRules(job) && node.getVersion().before(DetectionRule.VERSION_INTRODUCED)) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndVersion(node) + "], because jobs using " + - "custom_rules require a node of version [" + DetectionRule.VERSION_INTRODUCED + "] or higher"; - logger.trace(reason); - reasons.add(reason); - continue; - } - - long numberOfAssignedJobs = 0; - int numberOfAllocatingJobs = 0; - long assignedJobMemory = 0; - if (persistentTasks != null) { - // find all the job tasks assigned to this node - Collection> assignedTasks = persistentTasks.findTasks( - MlTasks.JOB_TASK_NAME, task -> node.getId().equals(task.getExecutorNode())); - for (PersistentTasksCustomMetaData.PersistentTask assignedTask : assignedTasks) { - JobState jobState = MlTasks.getJobStateModifiedForReassignments(assignedTask); - if (jobState.isAnyOf(JobState.CLOSED, JobState.FAILED) == 
false) { - // Don't count CLOSED or FAILED jobs, as they don't consume native memory - ++numberOfAssignedJobs; - if (jobState == JobState.OPENING) { - ++numberOfAllocatingJobs; - } - OpenJobAction.JobParams params = (OpenJobAction.JobParams) assignedTask.getParams(); - Long jobMemoryRequirement = memoryTracker.getAnomalyDetectorJobMemoryRequirement(params.getJobId()); - if (jobMemoryRequirement == null) { - allocateByMemory = false; - logger.debug("Falling back to allocating job [{}] by job counts because " + - "the memory requirement for job [{}] was not available", jobId, params.getJobId()); - } else { - assignedJobMemory += jobMemoryRequirement; - } - } - } - } - if (numberOfAllocatingJobs >= maxConcurrentJobAllocations) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) - + "], because node exceeds [" + numberOfAllocatingJobs + - "] the maximum number of jobs [" + maxConcurrentJobAllocations + "] in opening state"; - logger.trace(reason); - reasons.add(reason); - continue; - } - - Map nodeAttributes = node.getAttributes(); - int maxNumberOfOpenJobs = dynamicMaxOpenJobs; - // TODO: remove this in 8.0.0 - if (allNodesHaveDynamicMaxWorkers == false) { - String maxNumberOfOpenJobsStr = nodeAttributes.get(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR); - try { - maxNumberOfOpenJobs = Integer.parseInt(maxNumberOfOpenJobsStr); - } catch (NumberFormatException e) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + "], because " + - MachineLearning.MAX_OPEN_JOBS_NODE_ATTR + " attribute [" + maxNumberOfOpenJobsStr + "] is not an integer"; - logger.trace(reason); - reasons.add(reason); - continue; - } - } - long availableCount = maxNumberOfOpenJobs - numberOfAssignedJobs; - if (availableCount == 0) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) - + "], because this node is full. 
Number of opened jobs [" + numberOfAssignedJobs - + "], " + MAX_OPEN_JOBS_PER_NODE.getKey() + " [" + maxNumberOfOpenJobs + "]"; - logger.trace(reason); - reasons.add(reason); - continue; - } - - if (maxAvailableCount < availableCount) { - maxAvailableCount = availableCount; - minLoadedNodeByCount = node; - } - - String machineMemoryStr = nodeAttributes.get(MachineLearning.MACHINE_MEMORY_NODE_ATTR); - long machineMemory; - try { - machineMemory = Long.parseLong(machineMemoryStr); - } catch (NumberFormatException e) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + "], because " + - MachineLearning.MACHINE_MEMORY_NODE_ATTR + " attribute [" + machineMemoryStr + "] is not a long"; - logger.trace(reason); - reasons.add(reason); - continue; - } - - if (allocateByMemory) { - if (machineMemory > 0) { - long maxMlMemory = machineMemory * maxMachineMemoryPercent / 100; - Long estimatedMemoryFootprint = memoryTracker.getAnomalyDetectorJobMemoryRequirement(jobId); - if (estimatedMemoryFootprint != null) { - long availableMemory = maxMlMemory - assignedJobMemory; - if (estimatedMemoryFootprint > availableMemory) { - String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + - "], because this node has insufficient available memory. 
Available memory for ML [" + maxMlMemory + - "], memory required by existing jobs [" + assignedJobMemory + - "], estimated memory required for this job [" + estimatedMemoryFootprint + "]"; - logger.trace(reason); - reasons.add(reason); - continue; - } - - if (maxAvailableMemory < availableMemory) { - maxAvailableMemory = availableMemory; - minLoadedNodeByMemory = node; - } - } else { - // If we cannot get the job memory requirement, - // fall back to simply allocating by job count - allocateByMemory = false; - logger.debug("Falling back to allocating job [{}] by job counts because its memory requirement was not available", - jobId); - } - } else { - // If we cannot get the available memory on any machine in - // the cluster, fall back to simply allocating by job count - allocateByMemory = false; - logger.debug("Falling back to allocating job [{}] by job counts because machine memory was not available for node [{}]", - jobId, nodeNameAndMlAttributes(node)); - } - } - } - DiscoveryNode minLoadedNode = allocateByMemory ? 
minLoadedNodeByMemory : minLoadedNodeByCount; - if (minLoadedNode != null) { - logger.debug("selected node [{}] for job [{}]", minLoadedNode, jobId); - return new PersistentTasksCustomMetaData.Assignment(minLoadedNode.getId(), ""); - } else { - String explanation = String.join("|", reasons); - logger.debug("no node selected for job [{}], reasons [{}]", jobId, explanation); - return new PersistentTasksCustomMetaData.Assignment(null, explanation); - } - } - - static String nodeNameOrId(DiscoveryNode node) { - String nodeNameOrID = node.getName(); - if (Strings.isNullOrEmpty(nodeNameOrID)) { - nodeNameOrID = node.getId(); - } - return nodeNameOrID; - } - - static String nodeNameAndVersion(DiscoveryNode node) { - String nodeNameOrID = nodeNameOrId(node); - StringBuilder builder = new StringBuilder("{").append(nodeNameOrID).append('}'); - builder.append('{').append("version=").append(node.getVersion()).append('}'); - return builder.toString(); - } - - static String nodeNameAndMlAttributes(DiscoveryNode node) { - String nodeNameOrID = nodeNameOrId(node); - - StringBuilder builder = new StringBuilder("{").append(nodeNameOrID).append('}'); - for (Map.Entry entry : node.getAttributes().entrySet()) { - if (entry.getKey().startsWith("ml.") || entry.getKey().equals("node.ml")) { - builder.append('{').append(entry).append('}'); - } - } - return builder.toString(); - } - static String[] indicesOfInterest(String resultsIndex) { if (resultsIndex == null) { return new String[]{AnomalyDetectorsIndex.jobStateIndexPattern(), MlMetaIndex.INDEX_NAME}; @@ -381,6 +162,24 @@ private static boolean jobHasRules(Job job) { return job.getAnalysisConfig().getDetectors().stream().anyMatch(d -> d.getRules().isEmpty() == false); } + public static String nodeFilter(DiscoveryNode node, Job job) { + + String jobId = job.getId(); + + if (TransportOpenJobAction.nodeSupportsModelSnapshotVersion(node, job) == false) { + return "Not opening job [" + jobId + "] on node [" + 
JobNodeSelector.nodeNameAndVersion(node) + + "], because the job's model snapshot requires a node of version [" + + job.getModelSnapshotMinVersion() + "] or higher"; + } + + if (Job.getCompatibleJobTypes(node.getVersion()).contains(job.getJobType()) == false) { + return "Not opening job [" + jobId + "] on node [" + JobNodeSelector.nodeNameAndVersion(node) + + "], because this node does not support jobs of type [" + job.getJobType() + "]"; + } + + return null; + } + @Override protected String executor() { // This api doesn't do heavy or blocking operations (just delegates PersistentTasksService), @@ -541,7 +340,6 @@ public static class OpenJobPersistentTasksExecutor extends PersistentTasksExecut private final AutodetectProcessManager autodetectProcessManager; private final MlMemoryTracker memoryTracker; private final Client client; - private final ClusterService clusterService; private volatile int maxConcurrentJobAllocations; private volatile int maxMachineMemoryPercent; @@ -553,14 +351,13 @@ public OpenJobPersistentTasksExecutor(Settings settings, ClusterService clusterS AutodetectProcessManager autodetectProcessManager, MlMemoryTracker memoryTracker, Client client) { super(MlTasks.JOB_TASK_NAME, MachineLearning.UTILITY_THREAD_POOL_NAME); - this.autodetectProcessManager = autodetectProcessManager; - this.memoryTracker = memoryTracker; - this.client = client; + this.autodetectProcessManager = Objects.requireNonNull(autodetectProcessManager); + this.memoryTracker = Objects.requireNonNull(memoryTracker); + this.client = Objects.requireNonNull(client); this.maxConcurrentJobAllocations = MachineLearning.CONCURRENT_JOB_ALLOCATIONS.get(settings); this.maxMachineMemoryPercent = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings); this.maxLazyMLNodes = MachineLearning.MAX_LAZY_ML_NODES.get(settings); this.maxOpenJobs = MAX_OPEN_JOBS_PER_NODE.get(settings); - this.clusterService = clusterService; clusterService.getClusterSettings() 
.addSettingsUpdateConsumer(MachineLearning.CONCURRENT_JOB_ALLOCATIONS, this::setMaxConcurrentJobAllocations); clusterService.getClusterSettings() @@ -604,28 +401,11 @@ public PersistentTasksCustomMetaData.Assignment getAssignment(OpenJobAction.JobP } } - PersistentTasksCustomMetaData.Assignment assignment = selectLeastLoadedMlNode(jobId, - params.getJob(), - clusterState, - maxOpenJobs, - maxConcurrentJobAllocations, - maxMachineMemoryPercent, - memoryTracker, - isMemoryTrackerRecentlyRefreshed, - logger); - if (assignment.getExecutorNode() == null) { - int numMlNodes = 0; - for (DiscoveryNode node : clusterState.getNodes()) { - if (MachineLearning.isMlNode(node)) { - numMlNodes++; - } - } - - if (numMlNodes < maxLazyMLNodes) { // Means we have lazy nodes left to allocate - assignment = AWAITING_LAZY_ASSIGNMENT; - } - } - return assignment; + Job job = params.getJob(); + JobNodeSelector jobNodeSelector = new JobNodeSelector(clusterState, jobId, MlTasks.JOB_TASK_NAME, memoryTracker, + maxLazyMLNodes, node -> nodeFilter(node, job)); + return jobNodeSelector.selectNode( + maxOpenJobs, maxConcurrentJobAllocations, maxMachineMemoryPercent, isMemoryTrackerRecentlyRefreshed); } @Override @@ -640,7 +420,7 @@ public void validate(OpenJobAction.JobParams params, ClusterState clusterState) throw makeCurrentlyBeingUpgradedException(logger, params.getJobId(), assignment.getExplanation()); } - if (assignment.getExecutorNode() == null && assignment.equals(AWAITING_LAZY_ASSIGNMENT) == false) { + if (assignment.getExecutorNode() == null && assignment.equals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT) == false) { throw makeNoSuitableNodesException(logger, params.getJobId(), assignment.getExplanation()); } } @@ -756,7 +536,7 @@ public boolean test(PersistentTasksCustomMetaData.PersistentTask persistentTa PersistentTasksCustomMetaData.Assignment assignment = persistentTask.getAssignment(); // This means we are awaiting a new node to be spun up, ok to return back to the user to await 
node creation - if (assignment != null && assignment.equals(AWAITING_LAZY_ASSIGNMENT)) { + if (assignment != null && assignment.equals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT)) { return true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index 697dde824ddb3..fb10074c25763 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -23,6 +23,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.license.XPackLicenseState; @@ -37,6 +38,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -46,12 +48,16 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; +import org.elasticsearch.xpack.ml.job.JobNodeSelector; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import java.util.Map; import java.util.Objects; import java.util.function.Predicate; +import static 
org.elasticsearch.xpack.core.ml.MlTasks.AWAITING_UPGRADE; +import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE; + /** * Starts the persistent task for running data frame analytics. * @@ -201,6 +207,12 @@ public boolean test(PersistentTasksCustomMetaData.PersistentTask persistentTa } PersistentTasksCustomMetaData.Assignment assignment = persistentTask.getAssignment(); + + // This means we are awaiting a new node to be spun up, ok to return back to the user to await node creation + if (assignment != null && assignment.equals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT)) { + return true; + } + if (assignment != null && assignment.equals(PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT) == false && assignment.isAssigned() == false) { // Assignment has failed despite passing our "fast fail" validation @@ -265,10 +277,24 @@ public Long getReindexingTaskId() { public static class TaskExecutor extends PersistentTasksExecutor { private final DataFrameAnalyticsManager manager; + private final MlMemoryTracker memoryTracker; + + private volatile int maxMachineMemoryPercent; + private volatile int maxLazyMLNodes; + private volatile int maxOpenJobs; - public TaskExecutor(DataFrameAnalyticsManager manager) { + public TaskExecutor(Settings settings, ClusterService clusterService, DataFrameAnalyticsManager manager, + MlMemoryTracker memoryTracker) { super(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, MachineLearning.UTILITY_THREAD_POOL_NAME); this.manager = Objects.requireNonNull(manager); + this.memoryTracker = Objects.requireNonNull(memoryTracker); + this.maxMachineMemoryPercent = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings); + this.maxLazyMLNodes = MachineLearning.MAX_LAZY_ML_NODES.get(settings); + this.maxOpenJobs = MAX_OPEN_JOBS_PER_NODE.get(settings); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, this::setMaxMachineMemoryPercent); + 
clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes); + clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_OPEN_JOBS_PER_NODE, this::setMaxOpenJobs); } @Override @@ -280,10 +306,32 @@ protected AllocatedPersistentTask createTask( } @Override - protected DiscoveryNode selectLeastLoadedNode(ClusterState clusterState, Predicate selector) { - // For starters, let's just select the least loaded ML node - // TODO implement memory-based load balancing - return super.selectLeastLoadedNode(clusterState, MachineLearning::isMlNode); + public PersistentTasksCustomMetaData.Assignment getAssignment(StartDataFrameAnalyticsAction.TaskParams params, + ClusterState clusterState) { + + // If we are waiting for an upgrade to complete, we should not assign to a node + if (MlMetadata.getMlMetadata(clusterState).isUpgradeMode()) { + return AWAITING_UPGRADE; + } + + String id = params.getId(); + + boolean isMemoryTrackerRecentlyRefreshed = memoryTracker.isRecentlyRefreshed(); + if (isMemoryTrackerRecentlyRefreshed == false) { + boolean scheduledRefresh = memoryTracker.asyncRefresh(); + if (scheduledRefresh) { + String reason = "Not opening job [" + id + "] because job memory requirements are stale - refresh requested"; + LOGGER.debug(reason); + return new PersistentTasksCustomMetaData.Assignment(null, reason); + } + } + + JobNodeSelector jobNodeSelector = new JobNodeSelector(clusterState, id, MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, + maxLazyMLNodes, node -> nodeFilter(node, id)); + // Pass an effectively infinite value for max concurrent opening jobs, because data frame analytics jobs do + // not have an "opening" state so would never be rejected for causing too many jobs in the "opening" state + return jobNodeSelector.selectNode( + maxOpenJobs, Integer.MAX_VALUE, maxMachineMemoryPercent, isMemoryTrackerRecentlyRefreshed); } @Override @@ -306,7 +354,29 @@ protected void 
nodeOperation(AllocatedPersistentTask task, StartDataFrameAnalyti } else { manager.execute((DataFrameAnalyticsTask)task, analyticsTaskState.getState()); } + } + + public static String nodeFilter(DiscoveryNode node, String id) { + + if (node.getVersion().before(StartDataFrameAnalyticsAction.TaskParams.VERSION_INTRODUCED)) { + return "Not opening job [" + id + "] on node [" + JobNodeSelector.nodeNameAndVersion(node) + + "], because the data frame analytics requires a node of version [" + + StartDataFrameAnalyticsAction.TaskParams.VERSION_INTRODUCED + "] or higher"; + } + + return null; + } + + void setMaxMachineMemoryPercent(int maxMachineMemoryPercent) { + this.maxMachineMemoryPercent = maxMachineMemoryPercent; + } + + void setMaxLazyMLNodes(int maxLazyMLNodes) { + this.maxLazyMLNodes = maxLazyMLNodes; + } + void setMaxOpenJobs(int maxOpenJobs) { + this.maxOpenJobs = maxOpenJobs; } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java new file mode 100644 index 0000000000000..1ea80c8b95dc7 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java @@ -0,0 +1,328 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.job; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.Version; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Strings; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.job.config.JobState; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; + +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; + +import static org.elasticsearch.xpack.ml.MachineLearning.MAX_OPEN_JOBS_PER_NODE; + +/** + * Class that contains the logic to decide which node to assign each job to. + * + * The assignment rules are as follows: + * + * 1. Reject nodes that are not ML nodes + * 2. Reject nodes for which the node filter returns a rejection reason + * 3. Reject nodes where the new job would result in more than the permitted number of concurrent "opening" jobs + * 4. Reject nodes where the new job would result in more than the permitted number of assigned jobs + * 5. If assigning by memory, reject nodes where the new job would result in the permitted amount of memory being exceeded + * 6. If assigning by memory, pick the node that remains after rejections that has the most remaining memory + * 7. 
If assigning by count, pick the node that remains after rejections that has the fewest jobs assigned to it + * + * The decision on whether to assign by memory or by count is: + * - If values are available for every node's memory size and every job's memory requirement then assign by memory + * - Otherwise assign by count + */ +public class JobNodeSelector { + + public static final PersistentTasksCustomMetaData.Assignment AWAITING_LAZY_ASSIGNMENT = + new PersistentTasksCustomMetaData.Assignment(null, "persistent task is awaiting node assignment."); + + private static final Logger logger = LogManager.getLogger(JobNodeSelector.class); + + private final String jobId; + private final String taskName; + private final ClusterState clusterState; + private final MlMemoryTracker memoryTracker; + private final Function nodeFilter; + private final int maxLazyNodes; + + /** + * @param nodeFilter Optionally a function that returns a reason beyond the general + * reasons why a job cannot be assigned to a particular node. May + * be null if no such function is needed. + */ + public JobNodeSelector(ClusterState clusterState, String jobId, String taskName, MlMemoryTracker memoryTracker, int maxLazyNodes, + Function nodeFilter) { + this.jobId = Objects.requireNonNull(jobId); + this.taskName = Objects.requireNonNull(taskName); + this.clusterState = Objects.requireNonNull(clusterState); + this.memoryTracker = Objects.requireNonNull(memoryTracker); + this.maxLazyNodes = maxLazyNodes; + this.nodeFilter = node -> { + if (MachineLearning.isMlNode(node)) { + return (nodeFilter != null) ? 
nodeFilter.apply(node) : null; + } + return "Not opening job [" + jobId + "] on node [" + nodeNameOrId(node) + "], because this node isn't a ml node."; + }; + } + + public PersistentTasksCustomMetaData.Assignment selectNode(int dynamicMaxOpenJobs, int maxConcurrentJobAllocations, + int maxMachineMemoryPercent, boolean isMemoryTrackerRecentlyRefreshed) { + // TODO: remove in 8.0.0 + boolean allNodesHaveDynamicMaxWorkers = clusterState.getNodes().getMinNodeVersion().onOrAfter(Version.V_7_1_0); + + // Try to allocate jobs according to memory usage, but if that's not possible (maybe due to a mixed version cluster or maybe + // because of some weird OS problem) then fall back to the old mechanism of only considering numbers of assigned jobs + boolean allocateByMemory = isMemoryTrackerRecentlyRefreshed; + if (isMemoryTrackerRecentlyRefreshed == false) { + logger.warn("Falling back to allocating job [{}] by job counts because a memory requirement refresh could not be scheduled", + jobId); + } + + List reasons = new LinkedList<>(); + long maxAvailableCount = Long.MIN_VALUE; + long maxAvailableMemory = Long.MIN_VALUE; + DiscoveryNode minLoadedNodeByCount = null; + DiscoveryNode minLoadedNodeByMemory = null; + PersistentTasksCustomMetaData persistentTasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + for (DiscoveryNode node : clusterState.getNodes()) { + + // First check conditions that would rule out the node regardless of what other tasks are assigned to it + String reason = nodeFilter.apply(node); + if (reason != null) { + logger.trace(reason); + reasons.add(reason); + continue; + } + + // Assuming the node is elligible at all, check loading + CurrentLoad currentLoad = calculateCurrentLoadForNode(node, persistentTasks, allocateByMemory); + allocateByMemory = currentLoad.allocateByMemory; + + if (currentLoad.numberOfAllocatingJobs >= maxConcurrentJobAllocations) { + reason = "Not opening job [" + jobId + "] on node [" + 
nodeNameAndMlAttributes(node) + "], because node exceeds [" + + currentLoad.numberOfAllocatingJobs + "] the maximum number of jobs [" + maxConcurrentJobAllocations + + "] in opening state"; + logger.trace(reason); + reasons.add(reason); + continue; + } + + Map nodeAttributes = node.getAttributes(); + int maxNumberOfOpenJobs = dynamicMaxOpenJobs; + // TODO: remove this in 8.0.0 + if (allNodesHaveDynamicMaxWorkers == false) { + String maxNumberOfOpenJobsStr = nodeAttributes.get(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR); + try { + maxNumberOfOpenJobs = Integer.parseInt(maxNumberOfOpenJobsStr); + } catch (NumberFormatException e) { + reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + "], because " + + MachineLearning.MAX_OPEN_JOBS_NODE_ATTR + " attribute [" + maxNumberOfOpenJobsStr + "] is not an integer"; + logger.trace(reason); + reasons.add(reason); + continue; + } + } + long availableCount = maxNumberOfOpenJobs - currentLoad.numberOfAssignedJobs; + if (availableCount == 0) { + reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + + "], because this node is full. 
Number of opened jobs [" + currentLoad.numberOfAssignedJobs + + "], " + MAX_OPEN_JOBS_PER_NODE.getKey() + " [" + maxNumberOfOpenJobs + "]"; + logger.trace(reason); + reasons.add(reason); + continue; + } + + if (maxAvailableCount < availableCount) { + maxAvailableCount = availableCount; + minLoadedNodeByCount = node; + } + + String machineMemoryStr = nodeAttributes.get(MachineLearning.MACHINE_MEMORY_NODE_ATTR); + long machineMemory; + try { + machineMemory = Long.parseLong(machineMemoryStr); + } catch (NumberFormatException e) { + reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + "], because " + + MachineLearning.MACHINE_MEMORY_NODE_ATTR + " attribute [" + machineMemoryStr + "] is not a long"; + logger.trace(reason); + reasons.add(reason); + continue; + } + + if (allocateByMemory) { + if (machineMemory > 0) { + long maxMlMemory = machineMemory * maxMachineMemoryPercent / 100; + Long estimatedMemoryFootprint = memoryTracker.getJobMemoryRequirement(taskName, jobId); + if (estimatedMemoryFootprint != null) { + long availableMemory = maxMlMemory - currentLoad.assignedJobMemory; + if (estimatedMemoryFootprint > availableMemory) { + reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + + "], because this node has insufficient available memory. 
Available memory for ML [" + maxMlMemory + + "], memory required by existing jobs [" + currentLoad.assignedJobMemory + + "], estimated memory required for this job [" + estimatedMemoryFootprint + "]"; + logger.trace(reason); + reasons.add(reason); + continue; + } + + if (maxAvailableMemory < availableMemory) { + maxAvailableMemory = availableMemory; + minLoadedNodeByMemory = node; + } + } else { + // If we cannot get the job memory requirement, + // fall back to simply allocating by job count + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because its memory requirement was not available", + jobId); + } + } else { + // If we cannot get the available memory on any machine in + // the cluster, fall back to simply allocating by job count + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because machine memory was not available for node [{}]", + jobId, nodeNameAndMlAttributes(node)); + } + } + } + return createAssignment(allocateByMemory ? 
minLoadedNodeByMemory : minLoadedNodeByCount, reasons); + } + + private PersistentTasksCustomMetaData.Assignment createAssignment(DiscoveryNode minLoadedNode, List reasons) { + if (minLoadedNode == null) { + String explanation = String.join("|", reasons); + logger.debug("no node selected for job [{}], reasons [{}]", jobId, explanation); + return considerLazyAssignment(new PersistentTasksCustomMetaData.Assignment(null, explanation)); + } + logger.debug("selected node [{}] for job [{}]", minLoadedNode, jobId); + return new PersistentTasksCustomMetaData.Assignment(minLoadedNode.getId(), ""); + } + + PersistentTasksCustomMetaData.Assignment considerLazyAssignment(PersistentTasksCustomMetaData.Assignment currentAssignment) { + + assert currentAssignment.getExecutorNode() == null; + + int numMlNodes = 0; + for (DiscoveryNode node : clusterState.getNodes()) { + if (MachineLearning.isMlNode(node)) { + numMlNodes++; + } + } + + if (numMlNodes < maxLazyNodes) { // Means we have lazy nodes left to allocate + return AWAITING_LAZY_ASSIGNMENT; + } + + return currentAssignment; + } + + private CurrentLoad calculateCurrentLoadForNode(DiscoveryNode node, PersistentTasksCustomMetaData persistentTasks, + final boolean allocateByMemory) { + CurrentLoad result = new CurrentLoad(allocateByMemory); + + if (persistentTasks != null) { + // find all the anomaly detector job tasks assigned to this node + Collection> assignedAnomalyDetectorTasks = persistentTasks.findTasks( + MlTasks.JOB_TASK_NAME, task -> node.getId().equals(task.getExecutorNode())); + for (PersistentTasksCustomMetaData.PersistentTask assignedTask : assignedAnomalyDetectorTasks) { + JobState jobState = MlTasks.getJobStateModifiedForReassignments(assignedTask); + if (jobState.isAnyOf(JobState.CLOSED, JobState.FAILED) == false) { + // Don't count CLOSED or FAILED jobs, as they don't consume native memory + ++result.numberOfAssignedJobs; + if (jobState == JobState.OPENING) { + ++result.numberOfAllocatingJobs; + } + 
OpenJobAction.JobParams params = (OpenJobAction.JobParams) assignedTask.getParams(); + Long jobMemoryRequirement = memoryTracker.getAnomalyDetectorJobMemoryRequirement(params.getJobId()); + if (jobMemoryRequirement == null) { + result.allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because " + + "the memory requirement for job [{}] was not available", jobId, params.getJobId()); + } else { + logger.debug("adding " + jobMemoryRequirement); + result.assignedJobMemory += jobMemoryRequirement; + } + } + } + // find all the data frame analytics job tasks assigned to this node + Collection> assignedAnalyticsTasks = persistentTasks.findTasks( + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, task -> node.getId().equals(task.getExecutorNode())); + for (PersistentTasksCustomMetaData.PersistentTask assignedTask : assignedAnalyticsTasks) { + DataFrameAnalyticsState dataFrameAnalyticsState = ((DataFrameAnalyticsTaskState) assignedTask.getState()).getState(); + // TODO: skip FAILED here too if such a state is ever added + if (dataFrameAnalyticsState != DataFrameAnalyticsState.STOPPED) { + // The native process is only running in the ANALYZING and STOPPING states, but in the STARTED + // and REINDEXING states we're committed to using the memory soon, so account for it here + ++result.numberOfAssignedJobs; + StartDataFrameAnalyticsAction.TaskParams params = + (StartDataFrameAnalyticsAction.TaskParams) assignedTask.getParams(); + Long jobMemoryRequirement = memoryTracker.getDataFrameAnalyticsJobMemoryRequirement(params.getId()); + if (jobMemoryRequirement == null) { + result.allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because " + + "the memory requirement for job [{}] was not available", jobId, params.getId()); + } else { + result.assignedJobMemory += jobMemoryRequirement; + } + } + } + } + + return result; + } + + static String nodeNameOrId(DiscoveryNode node) { + String nodeNameOrID = 
node.getName(); + if (Strings.isNullOrEmpty(nodeNameOrID)) { + nodeNameOrID = node.getId(); + } + return nodeNameOrID; + } + + public static String nodeNameAndVersion(DiscoveryNode node) { + String nodeNameOrID = nodeNameOrId(node); + StringBuilder builder = new StringBuilder("{").append(nodeNameOrID).append('}'); + builder.append('{').append("version=").append(node.getVersion()).append('}'); + return builder.toString(); + } + + static String nodeNameAndMlAttributes(DiscoveryNode node) { + String nodeNameOrID = nodeNameOrId(node); + + StringBuilder builder = new StringBuilder("{").append(nodeNameOrID).append('}'); + for (Map.Entry entry : node.getAttributes().entrySet()) { + if (entry.getKey().startsWith("ml.") || entry.getKey().equals("node.ml")) { + builder.append('{').append(entry).append('}'); + } + } + return builder.toString(); + } + + private static class CurrentLoad { + + long numberOfAssignedJobs = 0; + long numberOfAllocatingJobs = 0; + long assignedJobMemory = 0; + boolean allocateByMemory; + + CurrentLoad(boolean allocateByMemory) { + this.allocateByMemory = allocateByMemory; + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java index 29116b320a885..09997406955da 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java @@ -32,8 +32,11 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Phaser; import java.util.stream.Collectors; @@ -55,8 +58,9 @@ public class MlMemoryTracker implements LocalNodeMasterListener { private static final Duration 
RECENT_UPDATE_THRESHOLD = Duration.ofMinutes(1); private final Logger logger = LogManager.getLogger(MlMemoryTracker.class); - private final ConcurrentHashMap memoryRequirementByAnomalyDetectorJob = new ConcurrentHashMap<>(); - private final ConcurrentHashMap memoryRequirementByDataFrameAnalyticsJob = new ConcurrentHashMap<>(); + private final Map memoryRequirementByAnomalyDetectorJob = new ConcurrentHashMap<>(); + private final Map memoryRequirementByDataFrameAnalyticsJob = new ConcurrentHashMap<>(); + private final Map> memoryRequirementByTaskName; private final List> fullRefreshCompletionListeners = new ArrayList<>(); private final ThreadPool threadPool; @@ -77,6 +81,12 @@ public MlMemoryTracker(Settings settings, ClusterService clusterService, ThreadP this.jobResultsProvider = jobResultsProvider; this.configProvider = configProvider; this.stopPhaser = new Phaser(1); + + Map> memoryRequirementByTaskName = new TreeMap<>(); + memoryRequirementByTaskName.put(MlTasks.JOB_TASK_NAME, memoryRequirementByAnomalyDetectorJob); + memoryRequirementByTaskName.put(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryRequirementByDataFrameAnalyticsJob); + this.memoryRequirementByTaskName = Collections.unmodifiableMap(memoryRequirementByTaskName); + setReassignmentRecheckInterval(PersistentTasksClusterService.CLUSTER_TASKS_ALLOCATION_RECHECK_INTERVAL_SETTING.get(settings)); clusterService.addLocalNodeMasterListener(this); clusterService.getClusterSettings().addSettingsUpdateConsumer( @@ -97,8 +107,9 @@ public void onMaster() { public void offMaster() { isMaster = false; logger.trace("ML memory tracker off master"); - memoryRequirementByAnomalyDetectorJob.clear(); - memoryRequirementByDataFrameAnalyticsJob.clear(); + for (Map memoryRequirementByJob : memoryRequirementByTaskName.values()) { + memoryRequirementByJob.clear(); + } lastUpdateTime = null; } @@ -142,17 +153,7 @@ public boolean isRecentlyRefreshed() { * or null if it cannot be calculated. 
*/ public Long getAnomalyDetectorJobMemoryRequirement(String jobId) { - - if (isMaster == false) { - return null; - } - - Long memoryRequirement = memoryRequirementByAnomalyDetectorJob.get(jobId); - if (memoryRequirement != null) { - return memoryRequirement; - } - - return null; + return getJobMemoryRequirement(MlTasks.JOB_TASK_NAME, jobId); } /** @@ -163,17 +164,29 @@ public Long getAnomalyDetectorJobMemoryRequirement(String jobId) { * or null if it cannot be found. */ public Long getDataFrameAnalyticsJobMemoryRequirement(String id) { + return getJobMemoryRequirement(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, id); + } + + /** + * Get the memory requirement for the type of job corresponding to a specified persistent task name. + * This method only works on the master node. + * @param taskName The persistent task name. + * @param id The job ID. + * @return The memory requirement of the job specified by {@code id}, + * or null if it cannot be found. + */ + public Long getJobMemoryRequirement(String taskName, String id) { if (isMaster == false) { return null; } - Long memoryRequirement = memoryRequirementByDataFrameAnalyticsJob.get(id); - if (memoryRequirement != null) { - return memoryRequirement; + Map memoryRequirementByJob = memoryRequirementByTaskName.get(taskName); + if (memoryRequirementByJob == null) { + return null; } - return null; + return memoryRequirementByJob.get(id); } /** @@ -255,7 +268,7 @@ public void addDataFrameAnalyticsJobMemoryAndRefreshAllOthers(String id, long me return; } - memoryRequirementByDataFrameAnalyticsJob.put(id, mem); + memoryRequirementByDataFrameAnalyticsJob.put(id, mem + DataFrameAnalyticsConfig.PROCESS_MEMORY_OVERHEAD.getBytes()); PersistentTasksCustomMetaData persistentTasks = clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE); refresh(persistentTasks, listener); @@ -334,7 +347,8 @@ private void refreshAllDataFrameAnalyticsJobTasks(List { for (DataFrameAnalyticsConfig analyticsConfig : 
analyticsConfigs) { - memoryRequirementByDataFrameAnalyticsJob.put(analyticsConfig.getId(), analyticsConfig.getModelMemoryLimit().getBytes()); + memoryRequirementByDataFrameAnalyticsJob.put(analyticsConfig.getId(), + analyticsConfig.getModelMemoryLimit().getBytes() + DataFrameAnalyticsConfig.PROCESS_MEMORY_OVERHEAD.getBytes()); } listener.onResponse(null); }, @@ -410,9 +424,9 @@ private void setAnomalyDetectorJobMemoryToLimit(String jobId, ActionListener { if (e instanceof ResourceNotFoundException) { // TODO: does this also happen if the .ml-config index exists but is unavailable? - logger.trace("[{}] job deleted during ML memory update", jobId); + logger.trace("[{}] anomaly detector job deleted during ML memory update", jobId); } else { - logger.error("[" + jobId + "] failed to get job during ML memory update", e); + logger.error("[" + jobId + "] failed to get anomaly detector job during ML memory update", e); } memoryRequirementByAnomalyDetectorJob.remove(jobId); listener.onResponse(null); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java index 040ed5e1d0ed4..cc9a0ba0181ad 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java @@ -16,8 +16,6 @@ import org.elasticsearch.cluster.metadata.IndexMetaData; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.MetaData; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.routing.IndexShardRoutingTable; import org.elasticsearch.cluster.routing.RecoverySource; @@ -27,9 +25,6 @@ import 
org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.transport.TransportAddress; -import org.elasticsearch.common.unit.ByteSizeUnit; -import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.index.Index; import org.elasticsearch.index.shard.ShardId; @@ -56,38 +51,19 @@ import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; import org.elasticsearch.xpack.ml.process.MlMemoryTracker; -import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; -import org.junit.Before; -import java.net.InetAddress; import java.util.ArrayList; import java.util.Collections; import java.util.Date; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; import static org.elasticsearch.xpack.core.ml.job.config.JobTests.buildJobBuilder; -import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -// TODO: in 8.0.0 remove all instances of MAX_OPEN_JOBS_NODE_ATTR from this file public class TransportOpenJobActionTests extends ESTestCase { - private MlMemoryTracker memoryTracker; - private boolean isMemoryTrackerRecentlyRefreshed; - - @Before - public void setup() { - memoryTracker = mock(MlMemoryTracker.class); - isMemoryTrackerRecentlyRefreshed = true; - when(memoryTracker.isRecentlyRefreshed()).thenReturn(isMemoryTrackerRecentlyRefreshed); - } - public void testValidate_jobMissing() { expectThrows(ResourceNotFoundException.class, () -> TransportOpenJobAction.validate("job_id2", null)); } @@ -113,347 +89,6 @@ public void testValidate_givenValidJob() { TransportOpenJobAction.validate("job_id", jobBuilder.build(new Date())); } - 
public void testSelectLeastLoadedMlNode_byCount() { - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "-1"); - // MachineLearning.MACHINE_MEMORY_NODE_ATTR negative, so this will fall back to allocating by count - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .add(new DiscoveryNode("_node_name3", "_node_id3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("job_id1", "_node_id1", null, tasksBuilder); - addJobTask("job_id2", "_node_id1", null, tasksBuilder); - addJobTask("job_id3", "_node_id2", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); - cs.nodes(nodes); - MetaData.Builder metaData = MetaData.builder(); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - cs.metaData(metaData); - - Job.Builder jobBuilder = buildJobBuilder("job_id4"); - jobBuilder.setJobVersion(Version.CURRENT); - - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id4", jobBuilder.build(), - cs.build(), 10, 2, 30, memoryTracker, isMemoryTrackerRecentlyRefreshed, logger); - assertEquals("", result.getExplanation()); - assertEquals("_node_id3", result.getExecutorNode()); - } - - public void testSelectLeastLoadedMlNode_maxCapacity() { - int numNodes = randomIntBetween(1, 10); - int maxRunningJobsPerNode = randomIntBetween(1, 
100); - - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, Integer.toString(maxRunningJobsPerNode)); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); - DiscoveryNodes.Builder nodes = DiscoveryNodes.builder(); - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - String[] jobIds = new String[numNodes * maxRunningJobsPerNode]; - for (int i = 0; i < numNodes; i++) { - String nodeId = "_node_id" + i; - TransportAddress address = new TransportAddress(InetAddress.getLoopbackAddress(), 9300 + i); - nodes.add(new DiscoveryNode("_node_name" + i, nodeId, address, nodeAttr, Collections.emptySet(), Version.CURRENT)); - for (int j = 0; j < maxRunningJobsPerNode; j++) { - int id = j + (maxRunningJobsPerNode * i); - jobIds[id] = "job_id" + id; - addJobTask(jobIds[id], nodeId, JobState.OPENED, tasksBuilder); - } - } - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); - MetaData.Builder metaData = MetaData.builder(); - cs.nodes(nodes); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - cs.metaData(metaData); - - Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id0", new ByteSizeValue(150, ByteSizeUnit.MB)).build(new Date()); - - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id0", job, cs.build(), maxRunningJobsPerNode, 2, - 30, memoryTracker, isMemoryTrackerRecentlyRefreshed, logger); - assertNull(result.getExecutorNode()); - assertTrue(result.getExplanation(), result.getExplanation().contains("because this node is full. 
Number of opened jobs [" - + maxRunningJobsPerNode + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]")); - } - - public void testSelectLeastLoadedMlNode_noMlNodes() { - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) - .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), - Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("job_id1", "_node_id1", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); - MetaData.Builder metaData = MetaData.builder(); - cs.nodes(nodes); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - cs.metaData(metaData); - - Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id2", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); - - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id2", job, cs.build(), 20, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertTrue(result.getExplanation().contains("because this node isn't a ml node")); - assertNull(result.getExecutorNode()); - } - - public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .add(new DiscoveryNode("_node_name2", "_node_id2", new 
TransportAddress(InetAddress.getLoopbackAddress(), 9301), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .add(new DiscoveryNode("_node_name3", "_node_id3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("job_id1", "_node_id1", null, tasksBuilder); - addJobTask("job_id2", "_node_id1", null, tasksBuilder); - addJobTask("job_id3", "_node_id2", null, tasksBuilder); - addJobTask("job_id4", "_node_id2", null, tasksBuilder); - addJobTask("job_id5", "_node_id3", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name")); - csBuilder.nodes(nodes); - MetaData.Builder metaData = MetaData.builder(); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - csBuilder.metaData(metaData); - - Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id6", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); - - ClusterState cs = csBuilder.build(); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id6", job, cs, 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertEquals("_node_id3", result.getExecutorNode()); - - tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); - addJobTask("job_id6", "_node_id3", null, tasksBuilder); - tasks = tasksBuilder.build(); - - csBuilder = ClusterState.builder(cs); - csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); - cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertNull("no node selected, because OPENING state", result.getExecutorNode()); - assertTrue(result.getExplanation().contains("because 
node exceeds [2] the maximum number of jobs [2] in opening state")); - - tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); - tasksBuilder.reassignTask(MlTasks.jobTaskId("job_id6"), new Assignment("_node_id3", "test assignment")); - tasks = tasksBuilder.build(); - - csBuilder = ClusterState.builder(cs); - csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); - cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertNull("no node selected, because stale task", result.getExecutorNode()); - assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); - - tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); - tasksBuilder.updateTaskState(MlTasks.jobTaskId("job_id6"), null); - tasks = tasksBuilder.build(); - - csBuilder = ClusterState.builder(cs); - csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); - cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertNull("no node selected, because null state", result.getExecutorNode()); - assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); - } - - public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob() { - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .add(new 
DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .add(new DiscoveryNode("_node_name3", "_node_id3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("job_id1", "_node_id1", JobState.fromString("failed"), tasksBuilder); - // This will make the allocation stale for job_id1 - tasksBuilder.reassignTask(MlTasks.jobTaskId("job_id1"), new Assignment("_node_id1", "test assignment")); - addJobTask("job_id2", "_node_id1", null, tasksBuilder); - addJobTask("job_id3", "_node_id2", null, tasksBuilder); - addJobTask("job_id4", "_node_id2", null, tasksBuilder); - addJobTask("job_id5", "_node_id3", null, tasksBuilder); - addJobTask("job_id6", "_node_id3", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name")); - csBuilder.nodes(nodes); - MetaData.Builder metaData = MetaData.builder(); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - csBuilder.metaData(metaData); - - ClusterState cs = csBuilder.build(); - Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id7", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); - - // Allocation won't be possible if the stale failed job is treated as opening - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertEquals("_node_id1", result.getExecutorNode()); - - tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); - addJobTask("job_id7", "_node_id1", null, tasksBuilder); - tasks = tasksBuilder.build(); - - csBuilder = ClusterState.builder(cs); - 
csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); - cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id8", job, cs, 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertNull("no node selected, because OPENING state", result.getExecutorNode()); - assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); - } - - public void testSelectLeastLoadedMlNode_noCompatibleJobTypeNodes() { - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), - nodeAttr, Collections.emptySet(), Version.CURRENT)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("incompatible_type_job", "_node_id1", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); - MetaData.Builder metaData = MetaData.builder(); - - Job job = mock(Job.class); - when(job.getId()).thenReturn("incompatible_type_job"); - when(job.getJobVersion()).thenReturn(Version.CURRENT); - when(job.getJobType()).thenReturn("incompatible_type"); - when(job.getInitialResultsIndexName()).thenReturn("shared"); - - cs.nodes(nodes); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - cs.metaData(metaData); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 10, 2, 30, - 
memoryTracker, isMemoryTrackerRecentlyRefreshed, logger); - assertThat(result.getExplanation(), containsString("because this node does not support jobs of type [incompatible_type]")); - assertNull(result.getExecutorNode()); - } - - public void testSelectLeastLoadedMlNode_noNodesMatchingModelSnapshotMinVersion() { - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - nodeAttr, Collections.emptySet(), Version.V_6_2_0)) - .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), - nodeAttr, Collections.emptySet(), Version.V_6_1_0)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("job_with_incompatible_model_snapshot", "_node_id1", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); - MetaData.Builder metaData = MetaData.builder(); - - Job job = BaseMlIntegTestCase.createFareQuoteJob("job_with_incompatible_model_snapshot") - .setModelSnapshotId("incompatible_snapshot") - .setModelSnapshotMinVersion(Version.V_6_3_0) - .build(new Date()); - cs.nodes(nodes); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - cs.metaData(metaData); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_incompatible_model_snapshot", job, cs.build(), 10, - 2, 30, memoryTracker, isMemoryTrackerRecentlyRefreshed, logger); - assertThat(result.getExplanation(), containsString( - "because the job's model snapshot requires a node of version [6.3.0] or higher")); - assertNull(result.getExecutorNode()); - } - - public void 
testSelectLeastLoadedMlNode_jobWithRulesButNoNodeMeetsRequiredVersion() { - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - nodeAttr, Collections.emptySet(), Version.V_6_2_0)) - .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), - nodeAttr, Collections.emptySet(), Version.V_6_3_0)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("job_with_rules", "_node_id1", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); - MetaData.Builder metaData = MetaData.builder(); - cs.nodes(nodes); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - cs.metaData(metaData); - - Job job = jobWithRules("job_with_rules"); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertThat(result.getExplanation(), containsString( - "because jobs using custom_rules require a node of version [6.4.0] or higher")); - assertNull(result.getExecutorNode()); - } - - public void testSelectLeastLoadedMlNode_jobWithRulesAndNodeMeetsRequiredVersion() { - Map nodeAttr = new HashMap<>(); - nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); - nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); - DiscoveryNodes nodes = DiscoveryNodes.builder() - .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), - nodeAttr, Collections.emptySet(), Version.V_6_2_0)) - .add(new 
DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), - nodeAttr, Collections.emptySet(), Version.V_6_4_0)) - .build(); - - PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); - addJobTask("job_with_rules", "_node_id1", null, tasksBuilder); - PersistentTasksCustomMetaData tasks = tasksBuilder.build(); - - ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); - MetaData.Builder metaData = MetaData.builder(); - cs.nodes(nodes); - metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); - cs.metaData(metaData); - - Job job = jobWithRules("job_with_rules"); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 10, 2, 30, memoryTracker, - isMemoryTrackerRecentlyRefreshed, logger); - assertNotNull(result.getExecutorNode()); - } - public void testVerifyIndicesPrimaryShardsAreActive() { MetaData.Builder metaData = MetaData.builder(); RoutingTable.Builder routingTable = RoutingTable.builder(); @@ -490,33 +125,6 @@ public void testVerifyIndicesPrimaryShardsAreActive() { assertEquals(indexToRemove, result.get(0)); } - public void testNodeNameAndVersion() { - TransportAddress ta = new TransportAddress(InetAddress.getLoopbackAddress(), 9300); - Map attributes = new HashMap<>(); - attributes.put("unrelated", "attribute"); - DiscoveryNode node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); - assertEquals("{_node_name1}{version=" + node.getVersion() + "}", TransportOpenJobAction.nodeNameAndVersion(node)); - } - - public void testNodeNameAndMlAttributes() { - TransportAddress ta = new TransportAddress(InetAddress.getLoopbackAddress(), 9300); - SortedMap attributes = new TreeMap<>(); - attributes.put("unrelated", "attribute"); - DiscoveryNode node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); - 
assertEquals("{_node_name1}", TransportOpenJobAction.nodeNameAndMlAttributes(node)); - - attributes.put("ml.machine_memory", "5"); - node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); - assertEquals("{_node_name1}{ml.machine_memory=5}", TransportOpenJobAction.nodeNameAndMlAttributes(node)); - - node = new DiscoveryNode(null, "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); - assertEquals("{_node_id1}{ml.machine_memory=5}", TransportOpenJobAction.nodeNameAndMlAttributes(node)); - - attributes.put("node.ml", "true"); - node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); - assertEquals("{_node_name1}{ml.machine_memory=5}{node.ml=true}", TransportOpenJobAction.nodeNameAndMlAttributes(node)); - } - public void testJobTaskMatcherMatch() { Task nonJobTask1 = mock(Task.class); Task nonJobTask2 = mock(Task.class); @@ -620,7 +228,7 @@ private void addIndices(MetaData.Builder metaData, RoutingTable.Builder routingT } } - private static Job jobWithRules(String jobId) { + public static Job jobWithRules(String jobId) { DetectionRule rule = new DetectionRule.Builder(Collections.singletonList( new RuleCondition(RuleCondition.AppliesTo.TYPICAL, Operator.LT, 100.0) )).build(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java new file mode 100644 index 0000000000000..4986029a839e7 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java @@ -0,0 +1,575 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.job; + +import org.elasticsearch.Version; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.job.config.JobState; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.action.TransportOpenJobAction; +import org.elasticsearch.xpack.ml.action.TransportOpenJobActionTests; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; +import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.junit.Before; + +import java.net.InetAddress; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import static org.elasticsearch.xpack.core.ml.job.config.JobTests.buildJobBuilder; +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +// TODO: in 8.0.0 remove all instances of MAX_OPEN_JOBS_NODE_ATTR from this file +public class 
JobNodeSelectorTests extends ESTestCase { + + // To simplify the the logic in this class all jobs have the same memory requirement + private static final ByteSizeValue JOB_MEMORY_REQUIREMENT = new ByteSizeValue(10, ByteSizeUnit.MB); + + private MlMemoryTracker memoryTracker; + private boolean isMemoryTrackerRecentlyRefreshed; + + @Before + public void setup() { + memoryTracker = mock(MlMemoryTracker.class); + isMemoryTrackerRecentlyRefreshed = true; + when(memoryTracker.isRecentlyRefreshed()).thenReturn(isMemoryTrackerRecentlyRefreshed); + when(memoryTracker.getAnomalyDetectorJobMemoryRequirement(anyString())).thenReturn(JOB_MEMORY_REQUIREMENT.getBytes()); + when(memoryTracker.getDataFrameAnalyticsJobMemoryRequirement(anyString())).thenReturn(JOB_MEMORY_REQUIREMENT.getBytes()); + when(memoryTracker.getJobMemoryRequirement(anyString(), anyString())).thenReturn(JOB_MEMORY_REQUIREMENT.getBytes()); + } + + public void testNodeNameAndVersion() { + TransportAddress ta = new TransportAddress(InetAddress.getLoopbackAddress(), 9300); + Map attributes = new HashMap<>(); + attributes.put("unrelated", "attribute"); + DiscoveryNode node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); + assertEquals("{_node_name1}{version=" + node.getVersion() + "}", JobNodeSelector.nodeNameAndVersion(node)); + } + + public void testNodeNameAndMlAttributes() { + TransportAddress ta = new TransportAddress(InetAddress.getLoopbackAddress(), 9300); + SortedMap attributes = new TreeMap<>(); + attributes.put("unrelated", "attribute"); + DiscoveryNode node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); + assertEquals("{_node_name1}", JobNodeSelector.nodeNameAndMlAttributes(node)); + + attributes.put("ml.machine_memory", "5"); + node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); + 
assertEquals("{_node_name1}{ml.machine_memory=5}", JobNodeSelector.nodeNameAndMlAttributes(node)); + + node = new DiscoveryNode(null, "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); + assertEquals("{_node_id1}{ml.machine_memory=5}", JobNodeSelector.nodeNameAndMlAttributes(node)); + + attributes.put("node.ml", "true"); + node = new DiscoveryNode("_node_name1", "_node_id1", ta, attributes, Collections.emptySet(), Version.CURRENT); + assertEquals("{_node_name1}{ml.machine_memory=5}{node.ml=true}", JobNodeSelector.nodeNameAndMlAttributes(node)); + } + + public void testSelectLeastLoadedMlNode_byCount() { + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "-1"); + // MachineLearning.MACHINE_MEMORY_NODE_ATTR negative, so this will fall back to allocating by count + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name3", "_node_id3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .build(); + + PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + TransportOpenJobActionTests.addJobTask("job_id1", "_node_id1", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id2", "_node_id1", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id3", "_node_id2", null, tasksBuilder); + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); + cs.nodes(nodes); + MetaData.Builder 
metaData = MetaData.builder(); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + cs.metaData(metaData); + + Job.Builder jobBuilder = buildJobBuilder("job_id4"); + jobBuilder.setJobVersion(Version.CURRENT); + + Job job = jobBuilder.build(); + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertEquals("", result.getExplanation()); + assertEquals("_node_id3", result.getExecutorNode()); + } + + public void testSelectLeastLoadedMlNodeForAnomalyDetectorJob_maxCapacityCountLimiting() { + int numNodes = randomIntBetween(1, 10); + int maxRunningJobsPerNode = randomIntBetween(1, 100); + int maxMachineMemoryPercent = 30; + long machineMemory = (maxRunningJobsPerNode + 1) * JOB_MEMORY_REQUIREMENT.getBytes() * 100 / maxMachineMemoryPercent; + + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, Integer.toString(maxRunningJobsPerNode)); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, Long.toString(machineMemory)); + + ClusterState.Builder cs = fillNodesWithRunningJobs(nodeAttr, numNodes, maxRunningJobsPerNode); + + Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date()); + + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = + jobNodeSelector.selectNode(maxRunningJobsPerNode, 2, maxMachineMemoryPercent, isMemoryTrackerRecentlyRefreshed); + assertNull(result.getExecutorNode()); + assertThat(result.getExplanation(), containsString("because this node is full. 
Number of opened jobs [" + + maxRunningJobsPerNode + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]")); + } + + public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_maxCapacityCountLimiting() { + int numNodes = randomIntBetween(1, 10); + int maxRunningJobsPerNode = randomIntBetween(1, 100); + int maxMachineMemoryPercent = 30; + long machineMemory = (maxRunningJobsPerNode + 1) * JOB_MEMORY_REQUIREMENT.getBytes() * 100 / maxMachineMemoryPercent; + + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, Integer.toString(maxRunningJobsPerNode)); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, Long.toString(machineMemory)); + + ClusterState.Builder cs = fillNodesWithRunningJobs(nodeAttr, numNodes, maxRunningJobsPerNode); + + String dataFrameAnalyticsId = "data_frame_analytics_id1000"; + + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), dataFrameAnalyticsId, + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, 0, + node -> TransportStartDataFrameAnalyticsAction.TaskExecutor.nodeFilter(node, dataFrameAnalyticsId)); + PersistentTasksCustomMetaData.Assignment result = + jobNodeSelector.selectNode(maxRunningJobsPerNode, 2, maxMachineMemoryPercent, isMemoryTrackerRecentlyRefreshed); + assertNull(result.getExecutorNode()); + assertThat(result.getExplanation(), containsString("because this node is full. 
Number of opened jobs [" + + maxRunningJobsPerNode + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]")); + } + + public void testSelectLeastLoadedMlNodeForAnomalyDetectorJob_maxCapacityMemoryLimiting() { + int numNodes = randomIntBetween(1, 10); + int currentlyRunningJobsPerNode = randomIntBetween(1, 100); + int maxRunningJobsPerNode = currentlyRunningJobsPerNode + 1; + // Be careful if changing this - in order for the error message to be exactly as expected + // the value here must divide exactly into (JOB_MEMORY_REQUIREMENT.getBytes() * 100) + int maxMachineMemoryPercent = 40; + long machineMemory = currentlyRunningJobsPerNode * JOB_MEMORY_REQUIREMENT.getBytes() * 100 / maxMachineMemoryPercent; + + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, Integer.toString(maxRunningJobsPerNode)); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, Long.toString(machineMemory)); + + ClusterState.Builder cs = fillNodesWithRunningJobs(nodeAttr, numNodes, currentlyRunningJobsPerNode); + + Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date()); + + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = + jobNodeSelector.selectNode(maxRunningJobsPerNode, 2, maxMachineMemoryPercent, isMemoryTrackerRecentlyRefreshed); + assertNull(result.getExecutorNode()); + assertThat(result.getExplanation(), containsString("because this node has insufficient available memory. 
" + + "Available memory for ML [" + (machineMemory * maxMachineMemoryPercent / 100) + "], memory required by existing jobs [" + + (JOB_MEMORY_REQUIREMENT.getBytes() * currentlyRunningJobsPerNode) + "], estimated memory required for this job [" + + JOB_MEMORY_REQUIREMENT.getBytes() + "]")); + } + + public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_maxCapacityMemoryLimiting() { + int numNodes = randomIntBetween(1, 10); + int currentlyRunningJobsPerNode = randomIntBetween(1, 100); + int maxRunningJobsPerNode = currentlyRunningJobsPerNode + 1; + // Be careful if changing this - in order for the error message to be exactly as expected + // the value here must divide exactly into (JOB_MEMORY_REQUIREMENT.getBytes() * 100) + int maxMachineMemoryPercent = 40; + long machineMemory = currentlyRunningJobsPerNode * JOB_MEMORY_REQUIREMENT.getBytes() * 100 / maxMachineMemoryPercent; + + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, Integer.toString(maxRunningJobsPerNode)); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, Long.toString(machineMemory)); + + ClusterState.Builder cs = fillNodesWithRunningJobs(nodeAttr, numNodes, currentlyRunningJobsPerNode); + + String dataFrameAnalyticsId = "data_frame_analytics_id1000"; + + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), dataFrameAnalyticsId, + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, memoryTracker, 0, + node -> TransportStartDataFrameAnalyticsAction.TaskExecutor.nodeFilter(node, dataFrameAnalyticsId)); + PersistentTasksCustomMetaData.Assignment result = + jobNodeSelector.selectNode(maxRunningJobsPerNode, 2, maxMachineMemoryPercent, isMemoryTrackerRecentlyRefreshed); + assertNull(result.getExecutorNode()); + assertThat(result.getExplanation(), containsString("because this node has insufficient available memory. 
" + + "Available memory for ML [" + (machineMemory * maxMachineMemoryPercent / 100) + "], memory required by existing jobs [" + + (JOB_MEMORY_REQUIREMENT.getBytes() * currentlyRunningJobsPerNode) + "], estimated memory required for this job [" + + JOB_MEMORY_REQUIREMENT.getBytes() + "]")); + } + + public void testSelectLeastLoadedMlNode_noMlNodes() { + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) + .build(); + + PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + TransportOpenJobActionTests.addJobTask("job_id1", "_node_id1", null, tasksBuilder); + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); + MetaData.Builder metaData = MetaData.builder(); + cs.nodes(nodes); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + cs.metaData(metaData); + + Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id2", JOB_MEMORY_REQUIREMENT).build(new Date()); + + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = jobNodeSelector.selectNode(20, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertTrue(result.getExplanation().contains("because this node isn't a ml node")); + assertNull(result.getExecutorNode()); + } + + public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); + 
nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name3", "_node_id3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .build(); + + PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + TransportOpenJobActionTests.addJobTask("job_id1", "_node_id1", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id2", "_node_id1", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id3", "_node_id2", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id4", "_node_id2", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id5", "_node_id3", null, tasksBuilder); + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name")); + csBuilder.nodes(nodes); + MetaData.Builder metaData = MetaData.builder(); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + csBuilder.metaData(metaData); + + Job job6 = BaseMlIntegTestCase.createFareQuoteJob("job_id6", JOB_MEMORY_REQUIREMENT).build(new Date()); + + ClusterState cs = csBuilder.build(); + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs, job6.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job6)); + PersistentTasksCustomMetaData.Assignment result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertEquals("_node_id3", result.getExecutorNode()); + + 
tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); + TransportOpenJobActionTests.addJobTask(job6.getId(), "_node_id3", null, tasksBuilder); + tasks = tasksBuilder.build(); + + csBuilder = ClusterState.builder(cs); + csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); + cs = csBuilder.build(); + + Job job7 = BaseMlIntegTestCase.createFareQuoteJob("job_id7", JOB_MEMORY_REQUIREMENT).build(new Date()); + jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job7)); + result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertNull("no node selected, because OPENING state", result.getExecutorNode()); + assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); + + tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); + tasksBuilder.reassignTask(MlTasks.jobTaskId(job6.getId()), + new PersistentTasksCustomMetaData.Assignment("_node_id3", "test assignment")); + tasks = tasksBuilder.build(); + + csBuilder = ClusterState.builder(cs); + csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); + cs = csBuilder.build(); + jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job7)); + result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertNull("no node selected, because stale task", result.getExecutorNode()); + assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); + + tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); + tasksBuilder.updateTaskState(MlTasks.jobTaskId(job6.getId()), null); + tasks = tasksBuilder.build(); + + csBuilder = ClusterState.builder(cs); + 
csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); + cs = csBuilder.build(); + jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job7)); + result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertNull("no node selected, because null state", result.getExecutorNode()); + assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); + } + + public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob() { + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name3", "_node_id3", new TransportAddress(InetAddress.getLoopbackAddress(), 9302), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .build(); + + PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + TransportOpenJobActionTests.addJobTask("job_id1", "_node_id1", JobState.fromString("failed"), tasksBuilder); + // This will make the allocation stale for job_id1 + tasksBuilder.reassignTask(MlTasks.jobTaskId("job_id1"), + new PersistentTasksCustomMetaData.Assignment("_node_id1", "test assignment")); + TransportOpenJobActionTests.addJobTask("job_id2", "_node_id1", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id3", "_node_id2", null, tasksBuilder); + 
TransportOpenJobActionTests.addJobTask("job_id4", "_node_id2", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id5", "_node_id3", null, tasksBuilder); + TransportOpenJobActionTests.addJobTask("job_id6", "_node_id3", null, tasksBuilder); + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name")); + csBuilder.nodes(nodes); + MetaData.Builder metaData = MetaData.builder(); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + csBuilder.metaData(metaData); + + ClusterState cs = csBuilder.build(); + Job job7 = BaseMlIntegTestCase.createFareQuoteJob("job_id7", JOB_MEMORY_REQUIREMENT).build(new Date()); + + // Allocation won't be possible if the stale failed job is treated as opening + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs, job7.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job7)); + PersistentTasksCustomMetaData.Assignment result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertEquals("_node_id1", result.getExecutorNode()); + + tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); + TransportOpenJobActionTests.addJobTask("job_id7", "_node_id1", null, tasksBuilder); + tasks = tasksBuilder.build(); + + csBuilder = ClusterState.builder(cs); + csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); + cs = csBuilder.build(); + Job job8 = BaseMlIntegTestCase.createFareQuoteJob("job_id8", JOB_MEMORY_REQUIREMENT).build(new Date()); + jobNodeSelector = new JobNodeSelector(cs, job8.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job8)); + result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertNull("no node selected, because OPENING state", result.getExecutorNode()); + 
assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); + } + + public void testSelectLeastLoadedMlNode_noCompatibleJobTypeNodes() { + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + nodeAttr, Collections.emptySet(), Version.CURRENT)) + .build(); + + PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + TransportOpenJobActionTests.addJobTask("incompatible_type_job", "_node_id1", null, tasksBuilder); + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); + MetaData.Builder metaData = MetaData.builder(); + + Job job = mock(Job.class); + when(job.getId()).thenReturn("incompatible_type_job"); + when(job.getJobVersion()).thenReturn(Version.CURRENT); + when(job.getJobType()).thenReturn("incompatible_type"); + when(job.getInitialResultsIndexName()).thenReturn("shared"); + + cs.nodes(nodes); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + cs.metaData(metaData); + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertThat(result.getExplanation(), containsString("because this node does not support jobs of type [incompatible_type]")); + 
assertNull(result.getExecutorNode()); + } + + public void testSelectLeastLoadedMlNode_noNodesMatchingModelSnapshotMinVersion() { + Map nodeAttr = new HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + nodeAttr, Collections.emptySet(), Version.V_6_2_0)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + nodeAttr, Collections.emptySet(), Version.V_6_1_0)) + .build(); + + PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + TransportOpenJobActionTests.addJobTask("job_with_incompatible_model_snapshot", "_node_id1", null, tasksBuilder); + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); + MetaData.Builder metaData = MetaData.builder(); + + Job job = BaseMlIntegTestCase.createFareQuoteJob("job_with_incompatible_model_snapshot") + .setModelSnapshotId("incompatible_snapshot") + .setModelSnapshotMinVersion(Version.V_6_3_0) + .build(new Date()); + cs.nodes(nodes); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + cs.metaData(metaData); + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), + MlTasks.JOB_TASK_NAME, memoryTracker, 0, node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertThat(result.getExplanation(), containsString( + "because the job's model snapshot requires a node of version [6.3.0] or higher")); + assertNull(result.getExecutorNode()); + } + + public void testSelectLeastLoadedMlNode_jobWithRules() { + Map nodeAttr = new 
HashMap<>(); + nodeAttr.put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, "10"); + nodeAttr.put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, "1000000000"); + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + nodeAttr, Collections.emptySet(), Version.V_6_2_0)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + nodeAttr, Collections.emptySet(), Version.V_6_4_0)) + .build(); + + PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + TransportOpenJobActionTests.addJobTask("job_with_rules", "_node_id1", null, tasksBuilder); + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); + MetaData.Builder metaData = MetaData.builder(); + cs.nodes(nodes); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + cs.metaData(metaData); + + Job job = TransportOpenJobActionTests.jobWithRules("job_with_rules"); + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = jobNodeSelector.selectNode(10, 2, 30, isMemoryTrackerRecentlyRefreshed); + assertNotNull(result.getExecutorNode()); + } + + public void testConsiderLazyAssignmentWithNoLazyNodes() { + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) + .build(); + + ClusterState.Builder cs = 
ClusterState.builder(new ClusterName("_name")); + cs.nodes(nodes); + + Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date()); + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, 0, + node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = + jobNodeSelector.considerLazyAssignment(new PersistentTasksCustomMetaData.Assignment(null, "foo")); + assertEquals("foo", result.getExplanation()); + assertNull(result.getExecutorNode()); + } + + public void testConsiderLazyAssignmentWithLazyNodes() { + DiscoveryNodes nodes = DiscoveryNodes.builder() + .add(new DiscoveryNode("_node_name1", "_node_id1", new TransportAddress(InetAddress.getLoopbackAddress(), 9300), + Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) + .add(new DiscoveryNode("_node_name2", "_node_id2", new TransportAddress(InetAddress.getLoopbackAddress(), 9301), + Collections.emptyMap(), Collections.emptySet(), Version.CURRENT)) + .build(); + + ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); + cs.nodes(nodes); + + Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id1000", JOB_MEMORY_REQUIREMENT).build(new Date()); + JobNodeSelector jobNodeSelector = new JobNodeSelector(cs.build(), job.getId(), MlTasks.JOB_TASK_NAME, memoryTracker, + randomIntBetween(1, 3), node -> TransportOpenJobAction.nodeFilter(node, job)); + PersistentTasksCustomMetaData.Assignment result = + jobNodeSelector.considerLazyAssignment(new PersistentTasksCustomMetaData.Assignment(null, "foo")); + assertEquals(JobNodeSelector.AWAITING_LAZY_ASSIGNMENT.getExplanation(), result.getExplanation()); + assertNull(result.getExecutorNode()); + } + + private ClusterState.Builder fillNodesWithRunningJobs(Map nodeAttr, int numNodes, int numRunningJobsPerNode) { + + DiscoveryNodes.Builder nodes = DiscoveryNodes.builder(); + 
PersistentTasksCustomMetaData.Builder tasksBuilder = PersistentTasksCustomMetaData.builder(); + String[] jobIds = new String[numNodes * numRunningJobsPerNode]; + for (int i = 0; i < numNodes; i++) { + String nodeId = "_node_id" + i; + TransportAddress address = new TransportAddress(InetAddress.getLoopbackAddress(), 9300 + i); + nodes.add(new DiscoveryNode("_node_name" + i, nodeId, address, nodeAttr, Collections.emptySet(), Version.CURRENT)); + for (int j = 0; j < numRunningJobsPerNode; j++) { + int id = j + (numRunningJobsPerNode * i); + // Both anomaly detector jobs and data frame analytics jobs should count towards the limit + if (randomBoolean()) { + jobIds[id] = "job_id" + id; + TransportOpenJobActionTests.addJobTask(jobIds[id], nodeId, JobState.OPENED, tasksBuilder); + } else { + jobIds[id] = "data_frame_analytics_id" + id; + addDataFrameAnalyticsJobTask(jobIds[id], nodeId, DataFrameAnalyticsState.STARTED, tasksBuilder); + } + } + } + PersistentTasksCustomMetaData tasks = tasksBuilder.build(); + + ClusterState.Builder cs = ClusterState.builder(new ClusterName("_name")); + MetaData.Builder metaData = MetaData.builder(); + cs.nodes(nodes); + metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); + cs.metaData(metaData); + + return cs; + } + + static void addDataFrameAnalyticsJobTask(String id, String nodeId, DataFrameAnalyticsState state, + PersistentTasksCustomMetaData.Builder builder) { + addDataFrameAnalyticsJobTask(id, nodeId, state, builder, false); + } + + static void addDataFrameAnalyticsJobTask(String id, String nodeId, DataFrameAnalyticsState state, + PersistentTasksCustomMetaData.Builder builder, boolean isStale) { + builder.addTask(MlTasks.dataFrameAnalyticsTaskId(id), MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + new StartDataFrameAnalyticsAction.TaskParams(id), new PersistentTasksCustomMetaData.Assignment(nodeId, "test assignment")); + if (state != null) { + builder.updateTaskState(MlTasks.dataFrameAnalyticsTaskId(id), + new 
DataFrameAnalyticsTaskState(state, builder.getLastAllocationId() - (isStale ? 1 : 0))); + } + } +} From b2bcaf2e03f9aacf111cde3a706612e9ea7851d2 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 16 Apr 2019 09:47:07 +0300 Subject: [PATCH 34/67] [FEATURE][ML] Change df-analytics config to accept single analysis (#41188) This changes the config to have a single analysis using the names writeables framework. `analyses_fields` are also renamed to `analyzed_fields`. Finally, during this work it was revealed that named writeables should not be declared from the `MachineLearning` plugin but in the `XPackClientPlugin` instead, so this is also done for the evaluation named writeables. Note that once the transport client is removed they will move back into the `MachineLearning` plugin. --- .../xpack/core/XPackClientPlugin.java | 29 ++++ .../ml/dataframe/DataFrameAnalysisConfig.java | 68 -------- .../dataframe/DataFrameAnalyticsConfig.java | 95 +++++------ .../dataframe/analyses/DataFrameAnalysis.java | 16 ++ ...ataFrameAnalysisNamedXContentProvider.java | 37 +++++ .../dataframe/analyses/OutlierDetection.java | 148 ++++++++++++++++++ .../persistence/ElasticsearchMappings.java | 11 +- .../ml/job/results/ReservedFieldNames.java | 11 +- ...DataFrameAnalyticsActionResponseTests.java | 13 +- ...tDataFrameAnalyticsActionRequestTests.java | 15 +- ...DataFrameAnalyticsActionResponseTests.java | 15 +- .../DataFrameAnalysisConfigTests.java | 47 ------ .../DataFrameAnalyticsConfigTests.java | 24 +-- .../analyses/OutlierDetectionTests.java | 36 +++++ .../ml/qa/ml-with-security/build.gradle | 5 +- .../integration/RunDataFrameAnalyticsIT.java | 8 +- .../xpack/ml/MachineLearning.java | 13 +- .../TransportPutDataFrameAnalyticsAction.java | 2 - .../analyses/AbstractDataFrameAnalysis.java | 28 ---- .../analyses/DataFrameAnalysesUtils.java | 80 ---------- .../dataframe/analyses/DataFrameAnalysis.java | 47 ------ .../dataframe/analyses/OutlierDetection.java | 64 -------- 
.../extractor/ExtractedFieldsDetector.java | 8 +- .../process/AnalyticsProcessConfig.java | 22 ++- .../process/AnalyticsProcessManager.java | 8 +- .../analyses/DataFrameAnalysesUtilsTests.java | 77 --------- .../analyses/OutlierDetectionTests.java | 60 ------- .../ExtractedFieldsDetectorTests.java | 9 +- .../test/ml/data_frame_analytics_crud.yml | 90 +++++------ .../test/ml/start_data_frame_analytics.yml | 4 +- 30 files changed, 450 insertions(+), 640 deletions(-) delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java delete mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java delete mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java diff 
--git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index f20ee9808f1c5..34811934ef367 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -132,6 +132,19 @@ import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MetricListEvaluationResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage; import 
org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage; @@ -406,6 +419,22 @@ public List getNamedWriteables() { DataFrameAnalyticsTaskState::new), new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MACHINE_LEARNING, MachineLearningFeatureSetUsage::new), + // ML - Data frame analytics + new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), + // ML - Data frame evaluation + new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), + BinarySoftClassification::new), + new NamedWriteableRegistry.Entry(EvaluationResult.class, MetricListEvaluationResult.NAME, MetricListEvaluationResult::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), AucRoc.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix.Result::new), + // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), // security diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java deleted file mode 100644 index 
cc4106218e098..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfig.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.dataframe; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ContextParser; -import org.elasticsearch.common.xcontent.ToXContentObject; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; - -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; - -public class DataFrameAnalysisConfig implements ToXContentObject, Writeable { - - public static ContextParser parser() { - return (p, c) -> new DataFrameAnalysisConfig(p.mapOrdered()); - } - - private final Map config; - - public DataFrameAnalysisConfig(Map config) { - this.config = Collections.unmodifiableMap(new HashMap<>(Objects.requireNonNull(config))); - if (config.size() != 1) { - throw ExceptionsHelper.badRequestException("A data frame analysis must specify exactly one analysis type"); - } - } - - public DataFrameAnalysisConfig(StreamInput in) throws IOException { - config = Collections.unmodifiableMap(in.readMap()); - } - - public Map asMap() { - return config; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeMap(config); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.map(config); - return builder; - } - - @Override - public boolean 
equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - DataFrameAnalysisConfig that = (DataFrameAnalysisConfig) o; - return Objects.equals(config, that.config); - } - - @Override - public int hashCode() { - return Objects.hash(config); - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java index f1fe8ced110a5..0e9acdd44a2fe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.core.ml.dataframe; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -15,16 +14,17 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; -import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Objects; @@ -42,9 +42,9 @@ public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { 
public static final ParseField ID = new ParseField("id"); public static final ParseField SOURCE = new ParseField("source"); public static final ParseField DEST = new ParseField("dest"); - public static final ParseField ANALYSES = new ParseField("analyses"); + public static final ParseField ANALYSIS = new ParseField("analysis"); public static final ParseField CONFIG_TYPE = new ParseField("config_type"); - public static final ParseField ANALYSES_FIELDS = new ParseField("analyses_fields"); + public static final ParseField ANALYZED_FIELDS = new ParseField("analyzed_fields"); public static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); public static final ParseField HEADERS = new ParseField("headers"); @@ -58,10 +58,10 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel parser.declareString(Builder::setId, ID); parser.declareObject(Builder::setSource, DataFrameAnalyticsSource.createParser(ignoreUnknownFields), SOURCE); parser.declareObject(Builder::setDest, DataFrameAnalyticsDest.createParser(ignoreUnknownFields), DEST); - parser.declareObjectArray(Builder::setAnalyses, DataFrameAnalysisConfig.parser(), ANALYSES); - parser.declareField(Builder::setAnalysesFields, + parser.declareObject(Builder::setAnalysis, (p, c) -> parseAnalysis(p, ignoreUnknownFields), ANALYSIS); + parser.declareField(Builder::setAnalyzedFields, (p, c) -> FetchSourceContext.fromXContent(p), - ANALYSES_FIELDS, + ANALYZED_FIELDS, OBJECT_ARRAY_BOOLEAN_OR_STRING); parser.declareField(Builder::setModelMemoryLimit, (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), MODEL_MEMORY_LIMIT.getPreferredName()), MODEL_MEMORY_LIMIT, VALUE); @@ -73,11 +73,19 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel return parser; } + private static DataFrameAnalysis parseAnalysis(XContentParser parser, boolean ignoreUnknownFields) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), 
parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + DataFrameAnalysis analysis = parser.namedObject(DataFrameAnalysis.class, parser.currentName(), ignoreUnknownFields); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return analysis; + } + private final String id; private final DataFrameAnalyticsSource source; private final DataFrameAnalyticsDest dest; - private final List analyses; - private final FetchSourceContext analysesFields; + private final DataFrameAnalysis analysis; + private final FetchSourceContext analyzedFields; /** * This may be null up to the point of persistence, as the relationship with xpack.ml.max_model_memory_limit * depends on whether the user explicitly set the value or if the default was requested. null indicates @@ -90,20 +98,13 @@ public static ObjectParser createParser(boolean ignoreUnknownFiel private final Map headers; public DataFrameAnalyticsConfig(String id, DataFrameAnalyticsSource source, DataFrameAnalyticsDest dest, - List analyses, Map headers, ByteSizeValue modelMemoryLimit, - FetchSourceContext analysesFields) { + DataFrameAnalysis analysis, Map headers, ByteSizeValue modelMemoryLimit, + FetchSourceContext analyzedFields) { this.id = ExceptionsHelper.requireNonNull(id, ID); this.source = ExceptionsHelper.requireNonNull(source, SOURCE); this.dest = ExceptionsHelper.requireNonNull(dest, DEST); - this.analyses = ExceptionsHelper.requireNonNull(analyses, ANALYSES); - if (analyses.isEmpty()) { - throw new ElasticsearchParseException("One or more analyses are required"); - } - // TODO Add support for multiple analyses - if (analyses.size() > 1) { - throw new UnsupportedOperationException("Does not yet support multiple analyses"); - } - this.analysesFields = analysesFields; + this.analysis = ExceptionsHelper.requireNonNull(analysis, ANALYSIS); + 
this.analyzedFields = analyzedFields; this.modelMemoryLimit = modelMemoryLimit; this.headers = Collections.unmodifiableMap(headers); } @@ -112,8 +113,8 @@ public DataFrameAnalyticsConfig(StreamInput in) throws IOException { id = in.readString(); source = new DataFrameAnalyticsSource(in); dest = new DataFrameAnalyticsDest(in); - analyses = in.readList(DataFrameAnalysisConfig::new); - this.analysesFields = in.readOptionalWriteable(FetchSourceContext::new); + analysis = in.readNamedWriteable(DataFrameAnalysis.class); + this.analyzedFields = in.readOptionalWriteable(FetchSourceContext::new); this.modelMemoryLimit = in.readOptionalWriteable(ByteSizeValue::new); this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); } @@ -130,12 +131,12 @@ public DataFrameAnalyticsDest getDest() { return dest; } - public List getAnalyses() { - return analyses; + public DataFrameAnalysis getAnalysis() { + return analysis; } - public FetchSourceContext getAnalysesFields() { - return analysesFields; + public FetchSourceContext getAnalyzedFields() { + return analyzedFields; } public ByteSizeValue getModelMemoryLimit() { @@ -152,12 +153,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(ID.getPreferredName(), id); builder.field(SOURCE.getPreferredName(), source); builder.field(DEST.getPreferredName(), dest); - builder.field(ANALYSES.getPreferredName(), analyses); + + builder.startObject(ANALYSIS.getPreferredName()); + builder.field(analysis.getWriteableName(), analysis); + builder.endObject(); + if (params.paramAsBoolean(ToXContentParams.INCLUDE_TYPE, false)) { builder.field(CONFIG_TYPE.getPreferredName(), TYPE); } - if (analysesFields != null) { - builder.field(ANALYSES_FIELDS.getPreferredName(), analysesFields); + if (analyzedFields != null) { + builder.field(ANALYZED_FIELDS.getPreferredName(), analyzedFields); } builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), 
getModelMemoryLimit().getStringRep()); if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { @@ -172,8 +177,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(id); source.writeTo(out); dest.writeTo(out); - out.writeList(analyses); - out.writeOptionalWriteable(analysesFields); + out.writeNamedWriteable(analysis); + out.writeOptionalWriteable(analyzedFields); out.writeOptionalWriteable(modelMemoryLimit); out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); } @@ -187,15 +192,15 @@ public boolean equals(Object o) { return Objects.equals(id, other.id) && Objects.equals(source, other.source) && Objects.equals(dest, other.dest) - && Objects.equals(analyses, other.analyses) + && Objects.equals(analysis, other.analysis) && Objects.equals(headers, other.headers) && Objects.equals(getModelMemoryLimit(), other.getModelMemoryLimit()) - && Objects.equals(analysesFields, other.analysesFields); + && Objects.equals(analyzedFields, other.analyzedFields); } @Override public int hashCode() { - return Objects.hash(id, source, dest, analyses, headers, getModelMemoryLimit(), analysesFields); + return Objects.hash(id, source, dest, analysis, headers, getModelMemoryLimit(), analyzedFields); } public static String documentId(String id) { @@ -207,8 +212,8 @@ public static class Builder { private String id; private DataFrameAnalyticsSource source; private DataFrameAnalyticsDest dest; - private List analyses; - private FetchSourceContext analysesFields; + private DataFrameAnalysis analysis; + private FetchSourceContext analyzedFields; private ByteSizeValue modelMemoryLimit; private ByteSizeValue maxModelMemoryLimit; private Map headers = Collections.emptyMap(); @@ -231,12 +236,12 @@ public Builder(DataFrameAnalyticsConfig config, ByteSizeValue maxModelMemoryLimi this.id = config.id; this.source = new DataFrameAnalyticsSource(config.source); this.dest = new 
DataFrameAnalyticsDest(config.dest); - this.analyses = new ArrayList<>(config.analyses); + this.analysis = config.analysis; this.headers = new HashMap<>(config.headers); this.modelMemoryLimit = config.modelMemoryLimit; this.maxModelMemoryLimit = maxModelMemoryLimit; - if (config.analysesFields != null) { - this.analysesFields = new FetchSourceContext(true, config.analysesFields.includes(), config.analysesFields.excludes()); + if (config.analyzedFields != null) { + this.analyzedFields = new FetchSourceContext(true, config.analyzedFields.includes(), config.analyzedFields.excludes()); } } @@ -259,13 +264,13 @@ public Builder setDest(DataFrameAnalyticsDest dest) { return this; } - public Builder setAnalyses(List analyses) { - this.analyses = ExceptionsHelper.requireNonNull(analyses, ANALYSES); + public Builder setAnalysis(DataFrameAnalysis analysis) { + this.analysis = ExceptionsHelper.requireNonNull(analysis, ANALYSIS); return this; } - public Builder setAnalysesFields(FetchSourceContext fields) { - this.analysesFields = fields; + public Builder setAnalyzedFields(FetchSourceContext fields) { + this.analyzedFields = fields; return this; } @@ -301,7 +306,7 @@ private void applyMaxModelMemoryLimit() { public DataFrameAnalyticsConfig build() { applyMaxModelMemoryLimit(); - return new DataFrameAnalyticsConfig(id, source, dest, analyses, headers, modelMemoryLimit, analysesFields); + return new DataFrameAnalyticsConfig(id, source, dest, analysis, headers, modelMemoryLimit, analyzedFields); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java new file mode 100644 index 0000000000000..f21533d917602 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +import java.util.Map; + +public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { + + Map getParams(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java new file mode 100644 index 0000000000000..a48a23e4a8393 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.ArrayList; +import java.util.List; + +public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + List namedXContent = new ArrayList<>(); + + namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> { + boolean ignoreUnknownFields = (boolean) c; + return OutlierDetection.fromXContent(p, ignoreUnknownFields); + })); + + return namedXContent; + } + + public List getNamedWriteables() { + List namedWriteables = new ArrayList<>(); + + namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), + OutlierDetection::new)); + + return namedWriteables; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java new file mode 100644 index 0000000000000..6d9e6cdd71163 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +public class OutlierDetection implements DataFrameAnalysis { + + public static final ParseField NAME = new ParseField("outlier_detection"); + + public static final ParseField NUMBER_NEIGHBORS = new ParseField("number_neighbors"); + public static final ParseField METHOD = new ParseField("method"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient, + a -> new OutlierDetection((Integer) a[0], (Method) a[1])); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_NEIGHBORS); + parser.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return Method.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, METHOD, ObjectParser.ValueType.STRING); + return parser; + } + + public static OutlierDetection fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return ignoreUnknownFields ? 
LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final Integer numberNeighbors; + private final Method method; + + /** + * Constructs the outlier detection configuration + * @param numberNeighbors The number of neighbors. Leave unspecified for dynamic detection. + * @param method The method. Leave unspecified for a dynamic mixture of methods. + */ + public OutlierDetection(@Nullable Integer numberNeighbors, @Nullable Method method) { + if (numberNeighbors != null && numberNeighbors <= 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", NUMBER_NEIGHBORS.getPreferredName()); + } + + this.numberNeighbors = numberNeighbors; + this.method = method; + } + + /** + * Constructs the default outlier detection configuration + */ + public OutlierDetection() { + this(null, null); + } + + public OutlierDetection(StreamInput in) throws IOException { + numberNeighbors = in.readOptionalVInt(); + method = in.readBoolean() ? in.readEnum(Method.class) : null; + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(numberNeighbors); + + if (method != null) { + out.writeBoolean(true); + out.writeEnum(method); + } else { + out.writeBoolean(false); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (numberNeighbors != null) { + builder.field(NUMBER_NEIGHBORS.getPreferredName(), numberNeighbors); + } + if (method != null) { + builder.field(METHOD.getPreferredName(), method); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OutlierDetection that = (OutlierDetection) o; + return Objects.equals(numberNeighbors, that.numberNeighbors) && 
Objects.equals(method, that.method); + } + + @Override + public int hashCode() { + return Objects.hash(numberNeighbors, method); + } + + @Override + public Map getParams() { + Map params = new HashMap<>(); + if (numberNeighbors != null) { + // TODO change this to the constant NEIGHBORS when c++ is updated to match + params.put("number_neighbours", numberNeighbors); + } + if (method != null) { + params.put(METHOD.getPreferredName(), method); + } + return params; + } + + public enum Method { + LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; + + public static Method fromString(String value) { + return Method.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index c33b2f2943b6d..c9c54d4f822de 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -29,6 +29,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -413,17 +414,17 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .endObject() .endObject() .endObject() - .startObject(DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName()) + 
.startObject(DataFrameAnalyticsConfig.ANALYZED_FIELDS.getPreferredName()) .field(ENABLED, false) .endObject() - .startObject(DataFrameAnalyticsConfig.ANALYSES.getPreferredName()) + .startObject(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName()) .startObject(PROPERTIES) - .startObject("outlier_detection") + .startObject(OutlierDetection.NAME.getPreferredName()) .startObject(PROPERTIES) - .startObject("number_neighbours") + .startObject(OutlierDetection.NUMBER_NEIGHBORS.getPreferredName()) .field(TYPE, INTEGER) .endObject() - .startObject("method") + .startObject(OutlierDetection.METHOD.getPreferredName()) .field(TYPE, KEYWORD) .endObject() .endObject() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 7530c120b0fd2..4dfb825d330f4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -262,15 +263,15 @@ public final class ReservedFieldNames { DataFrameAnalyticsConfig.ID.getPreferredName(), DataFrameAnalyticsConfig.SOURCE.getPreferredName(), DataFrameAnalyticsConfig.DEST.getPreferredName(), - DataFrameAnalyticsConfig.ANALYSES.getPreferredName(), - DataFrameAnalyticsConfig.ANALYSES_FIELDS.getPreferredName(), + 
DataFrameAnalyticsConfig.ANALYSIS.getPreferredName(), + DataFrameAnalyticsConfig.ANALYZED_FIELDS.getPreferredName(), DataFrameAnalyticsDest.INDEX.getPreferredName(), DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), DataFrameAnalyticsSource.INDEX.getPreferredName(), DataFrameAnalyticsSource.QUERY.getPreferredName(), - "outlier_detection", - "method", - "number_neighbours", + OutlierDetection.NAME.getPreferredName(), + OutlierDetection.NUMBER_NEIGHBORS.getPreferredName(), + OutlierDetection.METHOD.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE }; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java index 0c2a90195d238..38a3396316602 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import java.util.ArrayList; import java.util.Collections; @@ -23,14 +24,18 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractStreamable @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new 
MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); } @Override protected NamedXContentRegistry xContentRegistry() { - SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java index 633f34fd88576..d00fa4384be8a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java @@ -13,9 +13,12 @@ import org.elasticsearch.test.AbstractStreamableXContentTestCase; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Request; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.junit.Before; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; public class PutDataFrameAnalyticsActionRequestTests extends AbstractStreamableXContentTestCase { @@ -28,14 +31,18 @@ public void setUpId() { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - 
SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); } @Override protected NamedXContentRegistry xContentRegistry() { - SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java index 7830f874a4d6e..c9f678b13df2a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java @@ -12,21 +12,28 @@ import org.elasticsearch.test.AbstractStreamableTestCase; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; public class 
PutDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); } @Override protected NamedXContentRegistry xContentRegistry() { - SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java deleted file mode 100644 index a5dc889eea3db..0000000000000 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalysisConfigTests.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.core.ml.dataframe; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.test.AbstractSerializingTestCase; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -public class DataFrameAnalysisConfigTests extends AbstractSerializingTestCase { - - @Override - protected DataFrameAnalysisConfig createTestInstance() { - return randomConfig(); - } - - @Override - protected DataFrameAnalysisConfig doParseInstance(XContentParser parser) throws IOException { - return DataFrameAnalysisConfig.parser().parse(parser, null); - } - - @Override - protected Writeable.Reader instanceReader() { - return DataFrameAnalysisConfig::new; - } - - public static DataFrameAnalysisConfig randomConfig() { - Map configParams = new HashMap<>(); - int count = randomIntBetween(1, 5); - for (int i = 0; i < count; i++) { - if (randomBoolean()) { - configParams.put(randomAlphaOfLength(10), randomInt()); - } else { - configParams.put(randomAlphaOfLength(10), randomAlphaOfLength(10)); - } - } - Map config = new HashMap<>(); - config.put(randomAlphaOfLength(10), configParams); - return new DataFrameAnalysisConfig(config); - } -} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index d45043714da98..b12df363c0c5d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -27,9 +27,12 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.AbstractSerializingTestCase; +import 
org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -51,14 +54,18 @@ protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedWriteableRegistry(searchModule.getNamedWriteables()); + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); } @Override protected NamedXContentRegistry xContentRegistry() { - SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); } @Override @@ -78,14 +85,13 @@ public static DataFrameAnalyticsConfig createRandom(String id) { public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) { DataFrameAnalyticsSource source = DataFrameAnalyticsSourceTests.createRandom(); DataFrameAnalyticsDest dest = DataFrameAnalyticsDestTests.createRandom(); - List analyses = Collections.singletonList(DataFrameAnalysisConfigTests.randomConfig()); 
DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder() .setId(id) - .setAnalyses(analyses) + .setAnalysis(OutlierDetectionTests.createRandom()) .setSource(source) .setDest(dest); if (randomBoolean()) { - builder.setAnalysesFields(new FetchSourceContext(true, + builder.setAnalyzedFields(new FetchSourceContext(true, generateRandomStringArray(10, 10, false, false), generateRandomStringArray(10, 10, false, false))); } @@ -105,7 +111,7 @@ public static String randomValidId() { //query:match:type stopped being supported in 6.x " \"source\": {\"index\":\"my-index\", \"query\": {\"match\" : {\"query\":\"fieldName\", \"type\": \"phrase\"}}},\n" + " \"dest\": {\"index\":\"dest-index\"},\n" + - " \"analyses\": {\"outlier_detection\": {\"number_neighbours\": 10}}\n" + + " \"analysis\": {\"outlier_detection\": {\"number_neighbors\": 10}}\n" + "}"; private static final String MODERN_QUERY_DATA_FRAME_ANALYTICS = "{\n" + @@ -113,7 +119,7 @@ public static String randomValidId() { // match_all if parsed, adds default values in the options " \"source\": {\"index\":\"my-index\", \"query\": {\"match_all\" : {}}},\n" + " \"dest\": {\"index\":\"dest-index\"},\n" + - " \"analyses\": {\"outlier_detection\": {\"number_neighbours\": 10}}\n" + + " \"analysis\": {\"outlier_detection\": {\"number_neighbors\": 10}}\n" + "}"; public void testQueryConfigStoresUserInputOnly() throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java new file mode 100644 index 0000000000000..702afb7bb1c25 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +public class OutlierDetectionTests extends AbstractSerializingTestCase { + + @Override + protected OutlierDetection doParseInstance(XContentParser parser) throws IOException { + return OutlierDetection.fromXContent(parser, false); + } + + @Override + protected OutlierDetection createTestInstance() { + return createRandom(); + } + + public static OutlierDetection createRandom() { + Integer numberNeighbors = randomBoolean() ? null : randomIntBetween(1, 20); + OutlierDetection.Method method = randomBoolean() ? null : randomFrom(OutlierDetection.Method.values()); + return new OutlierDetection(numberNeighbors, method); + } + + @Override + protected Writeable.Reader instanceReader() { + return OutlierDetection::new; + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 33cbce301b4f3..51d2ef93e12b0 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -46,9 +46,8 @@ integTestRunner { 'ml/data_frame_analytics_crud/Test put config given missing dest', 'ml/data_frame_analytics_crud/Test put config given dest with empty index', 'ml/data_frame_analytics_crud/Test put config given dest without index', - 'ml/data_frame_analytics_crud/Test put config given missing analyses', - 'ml/data_frame_analytics_crud/Test put config given empty analyses', - 'ml/data_frame_analytics_crud/Test put config given two analyses', + 'ml/data_frame_analytics_crud/Test put config given missing analysis', + 'ml/data_frame_analytics_crud/Test put config given empty analysis', 
'ml/data_frame_analytics_crud/Test get given missing analytics', 'ml/data_frame_analytics_crud/Test delete given missing config', 'ml/data_frame_analytics_crud/Test max model memory limit', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index cce6b765d388a..0191967fe31e8 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -15,15 +15,13 @@ import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.junit.After; -import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; @@ -151,9 +149,7 @@ private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); configBuilder.setDest(new DataFrameAnalyticsDest(sourceIndex + "-results", resultsField)); - Map analysisConfig = new HashMap<>(); - analysisConfig.put("outlier_detection", Collections.emptyMap()); - 
configBuilder.setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig(analysisConfig))); + configBuilder.setAnalysis(new OutlierDetection()); return configBuilder.build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 98417cea897b7..95c026e053c86 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -115,6 +115,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction; import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; @@ -855,15 +856,11 @@ static long machineMemoryFromStats(OsStats stats) { return mem; } - @Override - public List getNamedWriteables() { - List namedWriteables = new ArrayList<>(); - namedWriteables.addAll(new MlEvaluationNamedXContentProvider().getNamedWriteables()); - return namedWriteables; - } - @Override public List getNamedXContent() { - return new MlEvaluationNamedXContentProvider().getNamedXContentParsers(); + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return namedXContent; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java index addabdb625553..372fb74edd198 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutDataFrameAnalyticsAction.java @@ -36,7 +36,6 @@ import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; import org.elasticsearch.xpack.core.security.authz.permission.ResourcePrivileges; import org.elasticsearch.xpack.core.security.support.Exceptions; -import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; import java.io.IOException; @@ -147,6 +146,5 @@ private void validateConfig(DataFrameAnalyticsConfig config) { throw ExceptionsHelper.badRequestException("id [{}] is too long; must not contain more than {} characters", config.getId(), MlStrings.ID_LENGTH_LIMIT); } - DataFrameAnalysesUtils.readAnalyses(config.getAnalyses()); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java deleted file mode 100644 index 90bcc839bb361..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/AbstractDataFrameAnalysis.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.ml.dataframe.analyses; - -import org.elasticsearch.common.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Map; - -public abstract class AbstractDataFrameAnalysis implements DataFrameAnalysis { - - private static final String NAME = "name"; - private static final String PARAMETERS = "parameters"; - - @Override - public final XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(NAME, getType()); - builder.field(PARAMETERS, getParams()); - builder.endObject(); - return builder; - } - - protected abstract Map getParams(); -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java deleted file mode 100644 index 5151d0c0c6e8d..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtils.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.ml.dataframe.analyses; - -import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -public final class DataFrameAnalysesUtils { - - private static final Map factories; - - static { - factories = new HashMap<>(); - factories.put(DataFrameAnalysis.Type.OUTLIER_DETECTION, new OutlierDetection.Factory()); - } - - private DataFrameAnalysesUtils() {} - - public static List readAnalyses(List analyses) { - return analyses.stream().map(DataFrameAnalysesUtils::readAnalysis).collect(Collectors.toList()); - } - - static DataFrameAnalysis readAnalysis(DataFrameAnalysisConfig config) { - Map configMap = config.asMap(); - DataFrameAnalysis.Type analysisType = DataFrameAnalysis.Type.fromString(configMap.keySet().iterator().next()); - DataFrameAnalysis.Factory factory = factories.get(analysisType); - Map analysisConfig = castAsMapAndCopy(analysisType, configMap.get(analysisType.toString())); - DataFrameAnalysis dataFrameAnalysis = factory.create(analysisConfig); - if (analysisConfig.isEmpty() == false) { - throw new ElasticsearchParseException("Data frame analysis [{}] does not support one or more provided parameters {}", - analysisType, analysisConfig.keySet()); - } - return dataFrameAnalysis; - } - - private static Map castAsMapAndCopy(DataFrameAnalysis.Type analysisType, Object obj) { - try { - return new HashMap<>((Map) obj); - } catch (ClassCastException e) { - throw new ElasticsearchParseException("[{}] expected to be a map but was of type [{}]", analysisType, obj.getClass().getName()); - } - } - - @Nullable - static Integer readInt(DataFrameAnalysis.Type analysisType, Map config, String property) { - Object value = config.remove(property); - if (value == null) { - return null; - } - try { - return (int) value; - 
} catch (ClassCastException e) { - throw new ElasticsearchParseException("Property [{}] of analysis [{}] should be of type [Integer] but was [{}]", - property, analysisType, value.getClass().getSimpleName()); - } - } - - @Nullable - static String readString(DataFrameAnalysis.Type analysisType, Map config, String property) { - Object value = config.remove(property); - if (value == null) { - return null; - } - try { - return (String) value; - } catch (ClassCastException e) { - throw new ElasticsearchParseException("Property [{}] of analysis [{}] should be of type [String] but was [{}]", - property, analysisType, value.getClass().getSimpleName()); - } - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java deleted file mode 100644 index 9fdd093fa324e..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysis.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.ml.dataframe.analyses; - -import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.common.xcontent.ToXContentObject; - -import java.util.Locale; -import java.util.Map; - -public interface DataFrameAnalysis extends ToXContentObject { - - enum Type { - OUTLIER_DETECTION; - - public static Type fromString(String value) { - try { - return Type.valueOf(value.toUpperCase(Locale.ROOT)); - } catch (IllegalArgumentException e) { - throw new ElasticsearchParseException("Unknown analysis type [{}]", value); - } - } - - @Override - public String toString() { - return name().toLowerCase(Locale.ROOT); - } - } - - Type getType(); - - interface Factory { - - /** - * Creates a data frame analysis based on the specified map of maps config. - * - * @param config The configuration for the analysis - * - * Note: Implementations are responsible for removing the used configuration keys, so that after - * creation it can be verified that all configurations settings have been used. - */ - DataFrameAnalysis create(Map config); - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java deleted file mode 100644 index 47f614ba658f6..0000000000000 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetection.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. 
- */ -package org.elasticsearch.xpack.ml.dataframe.analyses; - -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; - -public class OutlierDetection extends AbstractDataFrameAnalysis { - - public enum Method { - LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; - - public static Method fromString(String value) { - return Method.valueOf(value.toUpperCase(Locale.ROOT)); - } - - @Override - public String toString() { - return name().toLowerCase(Locale.ROOT); - } - } - - public static final String NUMBER_NEIGHBOURS = "number_neighbours"; - public static final String METHOD = "method"; - - private final Integer numberNeighbours; - private final Method method; - - public OutlierDetection(Integer numberNeighbours, Method method) { - this.numberNeighbours = numberNeighbours; - this.method = method; - } - - @Override - public Type getType() { - return Type.OUTLIER_DETECTION; - } - - @Override - protected Map getParams() { - Map params = new HashMap<>(); - if (numberNeighbours != null) { - params.put(NUMBER_NEIGHBOURS, numberNeighbours); - } - if (method != null) { - params.put(METHOD, method); - } - return params; - } - - static class Factory implements DataFrameAnalysis.Factory { - - @Override - public DataFrameAnalysis create(Map config) { - Integer numberNeighbours = DataFrameAnalysesUtils.readInt(Type.OUTLIER_DETECTION, config, NUMBER_NEIGHBOURS); - String method = DataFrameAnalysesUtils.readString(Type.OUTLIER_DETECTION, config, METHOD); - return new OutlierDetection(numberNeighbours, method == null ? 
null : Method.fromString(method)); - } - } -} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index 1c363b62bfaef..f1302ef72a3bd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -110,12 +110,12 @@ private void removeFieldsWithIncompatibleTypes(Set fields) { } private void includeAndExcludeFields(Set fields, String index) { - FetchSourceContext analysesFields = config.getAnalysesFields(); - if (analysesFields == null) { + FetchSourceContext analyzedFields = config.getAnalyzedFields(); + if (analyzedFields == null) { return; } - String includes = analysesFields.includes().length == 0 ? "*" : Strings.arrayToCommaDelimitedString(analysesFields.includes()); - String excludes = Strings.arrayToCommaDelimitedString(analysesFields.excludes()); + String includes = analyzedFields.includes().length == 0 ? 
"*" : Strings.arrayToCommaDelimitedString(analyzedFields.includes()); + String excludes = Strings.arrayToCommaDelimitedString(analyzedFields.excludes()); if (Regex.isMatchAllPattern(includes) && excludes.isEmpty()) { return; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 98f91fb601f87..226498376bbe1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -8,7 +8,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import java.io.IOException; import java.util.Objects; @@ -51,8 +51,26 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(MEMORY_LIMIT, memoryLimit.getBytes()); builder.field(THREADS, threads); builder.field(RESULTS_FIELD, resultsField); - builder.field(ANALYSIS, analysis); + builder.field(ANALYSIS, new DataFrameAnalysisWrapper(analysis)); builder.endObject(); return builder; } + + private static class DataFrameAnalysisWrapper implements ToXContentObject { + + private final DataFrameAnalysis analysis; + + private DataFrameAnalysisWrapper(DataFrameAnalysis analysis) { + this.analysis = analysis; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("name", analysis.getWriteableName()); + builder.field("parameters", analysis.getParams()); + builder.endObject(); + return builder; + } + } } diff 
--git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 6f939c7c18b11..52b308a3aa98e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -18,8 +18,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; -import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysesUtils; -import org.elasticsearch.xpack.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; @@ -145,12 +143,8 @@ private AnalyticsProcess createProcess(String jobId, AnalyticsProcessConfig anal private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); - List dataFrameAnalyses = DataFrameAnalysesUtils.readAnalyses(config.getAnalyses()); - // TODO We will not need this assertion after we add support for multiple analyses - assert dataFrameAnalyses.size() == 1; - AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, - config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), dataFrameAnalyses.get(0)); + config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis()); return processConfig; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java deleted file mode 100644 index b95fd32f7288d..0000000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/DataFrameAnalysesUtilsTests.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.ml.dataframe.analyses; - -import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.XContentType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; - -public class DataFrameAnalysesUtilsTests extends ESTestCase { - - public void testReadAnalysis_GivenEmptyAnalysisList() { - assertThat(DataFrameAnalysesUtils.readAnalyses(Collections.emptyList()).isEmpty(), is(true)); - } - - public void testReadAnalysis_GivenUnknownAnalysis() { - String analysisJson = "{\"unknown_analysis\": {}}"; - DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); - - ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, - () -> DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig))); - - assertThat(e.getMessage(), equalTo("Unknown analysis type [unknown_analysis]")); - } - - public void testReadAnalysis_GivenAnalysisIsNotAnObject() { - String analysisJson = "{\"outlier_detection\": 42}"; - 
DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); - - ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, - () -> DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig))); - - assertThat(e.getMessage(), equalTo("[outlier_detection] expected to be a map but was of type [java.lang.Integer]")); - } - - public void testReadAnalysis_GivenUnusedParameters() { - String analysisJson = "{\"outlier_detection\": {\"number_neighbours\":42, \"foo\": 1}}"; - DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); - - ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, - () -> DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig))); - - assertThat(e.getMessage(), equalTo("Data frame analysis [outlier_detection] does not support one or more provided " + - "parameters [foo]")); - } - - public void testReadAnalysis_GivenValidOutlierDetection() { - String analysisJson = "{\"outlier_detection\": {\"number_neighbours\":42}}"; - DataFrameAnalysisConfig analysisConfig = createAnalysisConfig(analysisJson); - - List analyses = DataFrameAnalysesUtils.readAnalyses(Collections.singletonList(analysisConfig)); - - assertThat(analyses.size(), equalTo(1)); - assertThat(analyses.get(0), is(instanceOf(OutlierDetection.class))); - OutlierDetection outlierDetection = (OutlierDetection) analyses.get(0); - assertThat(outlierDetection.getParams().size(), equalTo(1)); - assertThat(outlierDetection.getParams().get("number_neighbours"), equalTo(42)); - } - - private static DataFrameAnalysisConfig createAnalysisConfig(String json) { - Map asMap = XContentHelper.convertToMap(new BytesArray(json), true, XContentType.JSON).v2(); - return new DataFrameAnalysisConfig(asMap); - } -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java deleted file mode 100644 index 59a838acc8cd8..0000000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/analyses/OutlierDetectionTests.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.ml.dataframe.analyses; - -import org.elasticsearch.ElasticsearchParseException; -import org.elasticsearch.test.ESTestCase; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; - -public class OutlierDetectionTests extends ESTestCase { - - public void testCreate_GivenNumberNeighboursNotInt() { - Map config = new HashMap<>(); - config.put(OutlierDetection.NUMBER_NEIGHBOURS, "42"); - - DataFrameAnalysis.Factory factory = new OutlierDetection.Factory(); - - ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, () -> factory.create(config)); - assertThat(e.getMessage(), equalTo("Property [number_neighbours] of analysis [outlier_detection] should be of " + - "type [Integer] but was [String]")); - } - - public void testCreate_GivenMethodNotString() { - Map config = new HashMap<>(); - config.put(OutlierDetection.METHOD, 42); - - DataFrameAnalysis.Factory factory = new OutlierDetection.Factory(); - - ElasticsearchParseException e = expectThrows(ElasticsearchParseException.class, () -> factory.create(config)); - assertThat(e.getMessage(), equalTo("Property [method] of analysis [outlier_detection] should be of " + - "type [String] but was [Integer]")); - } - - public void testCreate_GivenEmptyParams() { - DataFrameAnalysis.Factory factory = new 
OutlierDetection.Factory(); - OutlierDetection outlierDetection = (OutlierDetection) factory.create(Collections.emptyMap()); - assertThat(outlierDetection.getParams().isEmpty(), is(true)); - } - - public void testCreate_GivenFullParams() { - Map config = new HashMap<>(); - config.put(OutlierDetection.NUMBER_NEIGHBOURS, 42); - config.put(OutlierDetection.METHOD, "ldof"); - - DataFrameAnalysis.Factory factory = new OutlierDetection.Factory(); - OutlierDetection outlierDetection = (OutlierDetection) factory.create(config); - - assertThat(outlierDetection.getParams().size(), equalTo(2)); - assertThat(outlierDetection.getParams().get(OutlierDetection.NUMBER_NEIGHBOURS), equalTo(42)); - assertThat(outlierDetection.getParams().get(OutlierDetection.METHOD), equalTo(OutlierDetection.Method.LDOF)); - } -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index 905349aa72840..3aa6bfd6480d1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -10,10 +10,10 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalysisConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import 
org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; @@ -222,13 +222,12 @@ private static DataFrameAnalyticsConfig buildAnalyticsConfig() { return buildAnalyticsConfig(null); } - private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analysesFields) { + private static DataFrameAnalyticsConfig buildAnalyticsConfig(FetchSourceContext analyzedFields) { return new DataFrameAnalyticsConfig.Builder("foo") .setSource(new DataFrameAnalyticsSource(SOURCE_INDEX, null)) .setDest(new DataFrameAnalyticsDest(DEST_INDEX, null)) - .setAnalysesFields(analysesFields) - .setAnalyses(Collections.singletonList(new DataFrameAnalysisConfig( - Collections.singletonMap("outlier_detection", Collections.emptyMap())))) + .setAnalyzedFields(analyzedFields) + .setAnalysis(new OutlierDetection()) .build(); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 38feb3fc58792..1cf450b26d669 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -40,15 +40,15 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}], - "analyses_fields": [ "obj1.*", "obj2.*" ] + "analysis": {"outlier_detection":{}}, + "analyzed_fields": [ "obj1.*", "obj2.*" ] } - match: { id: "simple-outlier-detection-with-query" } - match: { source.index: "index-source" } - match: { source.query: {"term" : { "user" : "Kimchy"} } } - match: { dest.index: "index-dest" } - - match: { analyses: [{"outlier_detection":{}}] } - - match: { analyses_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } + - match: { analysis: {"outlier_detection":{}} } + - match: { analyzed_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } --- "Test put config with security headers in the body": @@ 
-65,7 +65,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}], + "analysis": {"outlier_detection":{}}, "headers":{ "a_security_header" : "secret" } } @@ -83,13 +83,13 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - match: { id: "simple-outlier-detection" } - match: { source.index: "index-source" } - match: { source.query: {"match_all" : {} } } - match: { dest.index: "index-dest" } - - match: { analyses: [{"outlier_detection":{}}] } + - match: { analysis: {"outlier_detection":{}} } --- "Test put config with inconsistent body/param ids": @@ -107,7 +107,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } --- @@ -125,7 +125,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } --- @@ -143,7 +143,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}], + "analysis": {"outlier_detection":{}}, "unknown_field": 42 } @@ -151,7 +151,7 @@ "Test put config with unknown field in outlier detection analysis": - do: - catch: /Data frame analysis \[outlier_detection\] does not support one or more provided parameters \[unknown_field\]/ + catch: /unknown field \[unknown_field\]/ ml.put_data_frame_analytics: id: "unknown_field" body: > @@ -162,7 +162,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{"unknown_field": 42}}] + "analysis": {"outlier_detection":{"unknown_field":42}} } --- @@ -177,7 +177,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } --- @@ -195,7 +195,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } --- @@ -212,7 +212,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } 
--- @@ -227,7 +227,7 @@ "source": { "index": "index-source" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } --- @@ -245,7 +245,7 @@ "dest": { "index": "" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } --- @@ -262,14 +262,14 @@ }, "dest": { }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } --- -"Test put config given missing analyses": +"Test put config given missing analysis": - do: - catch: /\[analyses\] must not be null/ + catch: /\[analysis\] must not be null/ ml.put_data_frame_analytics: id: "simple-outlier-detection" body: > @@ -283,10 +283,10 @@ } --- -"Test put config given empty analyses": +"Test put config given empty analysis": - do: - catch: /One or more analyses are required/ + catch: /parsing_exception/ ml.put_data_frame_analytics: id: "simple-outlier-detection" body: > @@ -297,25 +297,7 @@ "dest": { "index": "index-dest" }, - "analyses": [] - } - ---- -"Test put config given two analyses": - - - do: - catch: /Does not yet support multiple analyses/ - ml.put_data_frame_analytics: - id: "simple-outlier-detection" - body: > - { - "source": { - "index": "index-source" - }, - "dest": { - "index": "index-dest" - }, - "analyses": [{"outlier_detection":{}}, {"outlier_detection":{}}] + "analysis": {} } --- @@ -332,7 +314,7 @@ "dest": { "index": "index-foo-1_dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - do: @@ -346,7 +328,7 @@ "dest": { "index": "index-foo-2_dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - match: { id: "foo-2" } @@ -361,7 +343,7 @@ "dest": { "index": "index-bar_dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - match: { id: "bar" } @@ -429,7 +411,7 @@ "dest": { "index": "index-foo-1_dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - do: @@ -443,7 +425,7 @@ "dest": 
{ "index": "index-foo-2_dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - match: { id: "foo-2" } @@ -458,7 +440,7 @@ "dest": { "index": "index-bar_dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - match: { id: "bar" } @@ -527,7 +509,7 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - do: @@ -575,9 +557,9 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}], + "analysis": {"outlier_detection":{}}, "model_memory_limit": "8gb", - "analyses_fields": [ "obj1.*", "obj2.*" ] + "analyzed_fields": [ "obj1.*", "obj2.*" ] } # Request using default higher than limit gets silently capped @@ -593,15 +575,15 @@ "dest": { "index": "index-dest" }, - "analyses": [{"outlier_detection":{}}], - "analyses_fields": [ "obj1.*", "obj2.*" ] + "analysis": {"outlier_detection":{}}, + "analyzed_fields": [ "obj1.*", "obj2.*" ] } - match: { id: "simple-outlier-detection-with-query" } - match: { source.index: "index-source" } - match: { source.query: {"term" : { "user" : "Kimchy"} } } - match: { dest.index: "index-dest" } - - match: { analyses: [{"outlier_detection":{}}] } - - match: { analyses_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } + - match: { analysis: {"outlier_detection":{}} } + - match: { analyzed_fields: {"includes" : ["obj1.*", "obj2.*" ], "excludes": [] } } - match: { model_memory_limit: "20mb" } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml index 8149565238730..3cc2e91483e18 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml @@ -19,7 +19,7 @@ "dest": { "index": "missing-dest" }, - "analyses": 
[{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - do: @@ -45,7 +45,7 @@ "dest": { "index": "empty-index-dest" }, - "analyses": [{"outlier_detection":{}}] + "analysis": {"outlier_detection":{}} } - do: From f1140d5bce562679fc769bbb2789f3d776d7b906 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 18 Apr 2019 18:01:46 +0300 Subject: [PATCH 35/67] [FEATURE][ML] Rename number_neighbors to n_neighbors (#41242) Shorter plus american spelling. What's more to like? --- .../dataframe/analyses/OutlierDetection.java | 33 +++++++++---------- .../persistence/ElasticsearchMappings.java | 2 +- .../ml/job/results/ReservedFieldNames.java | 2 +- .../DataFrameAnalyticsConfigTests.java | 4 +-- .../analyses/OutlierDetectionTests.java | 19 +++++++++++ 5 files changed, 39 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 6d9e6cdd71163..5cd00fa979550 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -25,7 +25,7 @@ public class OutlierDetection implements DataFrameAnalysis { public static final ParseField NAME = new ParseField("outlier_detection"); - public static final ParseField NUMBER_NEIGHBORS = new ParseField("number_neighbors"); + public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); public static final ParseField METHOD = new ParseField("method"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); @@ -34,7 +34,7 @@ public class OutlierDetection implements DataFrameAnalysis { private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser parser = new 
ConstructingObjectParser<>(NAME.getPreferredName(), lenient, a -> new OutlierDetection((Integer) a[0], (Method) a[1])); - parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUMBER_NEIGHBORS); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), N_NEIGHBORS); parser.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { return Method.fromString(p.text()); @@ -48,20 +48,20 @@ public static OutlierDetection fromXContent(XContentParser parser, boolean ignor return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); } - private final Integer numberNeighbors; + private final Integer nNeighbors; private final Method method; /** * Constructs the outlier detection configuration - * @param numberNeighbors The number of neighbors. Leave unspecified for dynamic detection. + * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. * @param method The method. Leave unspecified for a dynamic mixture of methods. */ - public OutlierDetection(@Nullable Integer numberNeighbors, @Nullable Method method) { - if (numberNeighbors != null && numberNeighbors <= 0) { - throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", NUMBER_NEIGHBORS.getPreferredName()); + public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method) { + if (nNeighbors != null && nNeighbors <= 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", N_NEIGHBORS.getPreferredName()); } - this.numberNeighbors = numberNeighbors; + this.nNeighbors = nNeighbors; this.method = method; } @@ -73,7 +73,7 @@ public OutlierDetection() { } public OutlierDetection(StreamInput in) throws IOException { - numberNeighbors = in.readOptionalVInt(); + nNeighbors = in.readOptionalVInt(); method = in.readBoolean() ? 
in.readEnum(Method.class) : null; } @@ -84,7 +84,7 @@ public String getWriteableName() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalVInt(numberNeighbors); + out.writeOptionalVInt(nNeighbors); if (method != null) { out.writeBoolean(true); @@ -97,8 +97,8 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - if (numberNeighbors != null) { - builder.field(NUMBER_NEIGHBORS.getPreferredName(), numberNeighbors); + if (nNeighbors != null) { + builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors); } if (method != null) { builder.field(METHOD.getPreferredName(), method); @@ -112,20 +112,19 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; OutlierDetection that = (OutlierDetection) o; - return Objects.equals(numberNeighbors, that.numberNeighbors) && Objects.equals(method, that.method); + return Objects.equals(nNeighbors, that.nNeighbors) && Objects.equals(method, that.method); } @Override public int hashCode() { - return Objects.hash(numberNeighbors, method); + return Objects.hash(nNeighbors, method); } @Override public Map getParams() { Map params = new HashMap<>(); - if (numberNeighbors != null) { - // TODO change this to the constant NEIGHBORS when c++ is updated to match - params.put("number_neighbours", numberNeighbors); + if (nNeighbors != null) { + params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); } if (method != null) { params.put(METHOD.getPreferredName(), method); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index c9c54d4f822de..ac0c4a420448b 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -421,7 +421,7 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(PROPERTIES) .startObject(OutlierDetection.NAME.getPreferredName()) .startObject(PROPERTIES) - .startObject(OutlierDetection.NUMBER_NEIGHBORS.getPreferredName()) + .startObject(OutlierDetection.N_NEIGHBORS.getPreferredName()) .field(TYPE, INTEGER) .endObject() .startObject(OutlierDetection.METHOD.getPreferredName()) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index bef72381f5d0f..75b4d9e7777ae 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -275,7 +275,7 @@ public final class ReservedFieldNames { DataFrameAnalyticsSource.INDEX.getPreferredName(), DataFrameAnalyticsSource.QUERY.getPreferredName(), OutlierDetection.NAME.getPreferredName(), - OutlierDetection.NUMBER_NEIGHBORS.getPreferredName(), + OutlierDetection.N_NEIGHBORS.getPreferredName(), OutlierDetection.METHOD.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index b12df363c0c5d..a5df1f83c3d37 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -111,7 +111,7 @@ public static String randomValidId() { //query:match:type stopped being supported in 6.x " \"source\": {\"index\":\"my-index\", \"query\": {\"match\" : {\"query\":\"fieldName\", \"type\": \"phrase\"}}},\n" + " \"dest\": {\"index\":\"dest-index\"},\n" + - " \"analysis\": {\"outlier_detection\": {\"number_neighbors\": 10}}\n" + + " \"analysis\": {\"outlier_detection\": {\"n_neighbors\": 10}}\n" + "}"; private static final String MODERN_QUERY_DATA_FRAME_ANALYTICS = "{\n" + @@ -119,7 +119,7 @@ public static String randomValidId() { // match_all if parsed, adds default values in the options " \"source\": {\"index\":\"my-index\", \"query\": {\"match_all\" : {}}},\n" + " \"dest\": {\"index\":\"dest-index\"},\n" + - " \"analysis\": {\"outlier_detection\": {\"number_neighbors\": 10}}\n" + + " \"analysis\": {\"outlier_detection\": {\"n_neighbors\": 10}}\n" + "}"; public void testQueryConfigStoresUserInputOnly() throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java index 702afb7bb1c25..0e3d826593258 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java @@ -10,6 +10,10 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; public class OutlierDetectionTests extends AbstractSerializingTestCase { @@ -33,4 +37,19 @@ public static OutlierDetection createRandom() { protected Writeable.Reader instanceReader() { return 
OutlierDetection::new; } + + public void testGetParams_GivenDefaults() { + OutlierDetection outlierDetection = new OutlierDetection(); + assertThat(outlierDetection.getParams().isEmpty(), is(true)); + } + + public void testGetParams_GivenExplicitValues() { + OutlierDetection outlierDetection = new OutlierDetection(42, OutlierDetection.Method.LDOF); + + Map params = outlierDetection.getParams(); + + assertThat(params.size(), equalTo(2)); + assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42)); + assertThat(params.get(OutlierDetection.METHOD.getPreferredName()), equalTo(OutlierDetection.Method.LDOF)); + } } From 1cd470cb1d7b4781fb993ecf9c91e74399a7d70a Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 23 Apr 2019 17:11:48 +0100 Subject: [PATCH 36/67] Fix failing test as a result of a bad merge --- .../xpack/core/ml/job/results/ReservedFieldNames.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 75b4d9e7777ae..f727f637b972b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -282,6 +282,7 @@ public final class ReservedFieldNames { GetResult._ID, GetResult._INDEX, + GetResult._TYPE }; /** From 60378e0f56639e36c7fc1cfb1fcb63924be92ca3 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Wed, 24 Apr 2019 17:21:31 +0100 Subject: [PATCH 37/67] [ML] Make a copy of unmodifiable sett (#41490) After #34071 the FieldCapabilitiesResponse response map is unmodifiable --- .../xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index f1302ef72a3bd..96f0181b1416c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -67,7 +68,7 @@ public class ExtractedFieldsDetector { } public ExtractedFields detect() { - Set fields = fieldCapabilitiesResponse.get().keySet(); + Set fields = new HashSet<>(fieldCapabilitiesResponse.get().keySet()); fields.removeAll(IGNORE_FIELDS); checkResultsFieldIsNotPresent(fields, index); From 901bab242c65774e2ac13175c23380a3909b84d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Thu, 25 Apr 2019 13:57:31 +0200 Subject: [PATCH 38/67] [FEATURE][ML] Client-side data structures for DataFrame analytics config. (#41438) Implement data model classes for DataFrameAnalytics High-Level Rest Client. 
--- .../ml/dataframe/DataFrameAnalysis.java | 31 +++ .../dataframe/DataFrameAnalyticsConfig.java | 232 ++++++++++++++++++ .../ml/dataframe/DataFrameAnalyticsDest.java | 114 +++++++++ .../dataframe/DataFrameAnalyticsSource.java | 112 +++++++++ ...ataFrameAnalysisNamedXContentProvider.java | 39 +++ .../client/ml/dataframe/OutlierDetection.java | 153 ++++++++++++ .../client/ml/dataframe/QueryConfig.java | 88 +++++++ .../DataFrameAnalyticsConfigTests.java | 88 +++++++ .../DataFrameAnalyticsDestTests.java | 47 ++++ .../DataFrameAnalyticsSourceTests.java | 69 ++++++ .../ml/dataframe/OutlierDetectionTests.java | 68 +++++ .../client/ml/dataframe/QueryConfigTests.java | 62 +++++ 12 files changed, 1103 insertions(+) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java create mode 100644 
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java new file mode 100644 index 0000000000000..585d135700aa4 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java @@ -0,0 +1,31 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +import java.util.Map; + +public interface DataFrameAnalysis extends ToXContentObject { + + String getName(); + + Map getParams(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java new file mode 100644 index 0000000000000..a5ede0e9128be --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java @@ -0,0 +1,232 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING; +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.VALUE; + +public class DataFrameAnalyticsConfig implements ToXContentObject { + + public static DataFrameAnalyticsConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + public static Builder builder(String id) { + return new Builder(id); + } + + private static final String NAME = "data_frame_analytics_config"; + + private static final ParseField ID = new ParseField("id"); + private static final ParseField SOURCE = new ParseField("source"); + private static final ParseField DEST = new ParseField("dest"); + private static final ParseField ANALYSIS = new ParseField("analysis"); + private static final ParseField ANALYZED_FIELDS = new ParseField("analyzed_fields"); + private static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); + + private static ObjectParser PARSER = new ObjectParser<>(NAME, true, Builder::new); + + static { + PARSER.declareString(Builder::setId, ID); + PARSER.declareObject(Builder::setSource, (p, c) -> DataFrameAnalyticsSource.fromXContent(p), SOURCE); + PARSER.declareObject(Builder::setDest, (p, c) -> 
DataFrameAnalyticsDest.fromXContent(p), DEST); + PARSER.declareObject(Builder::setAnalysis, (p, c) -> parseAnalysis(p), ANALYSIS); + PARSER.declareField(Builder::setAnalyzedFields, + (p, c) -> FetchSourceContext.fromXContent(p), + ANALYZED_FIELDS, + OBJECT_ARRAY_BOOLEAN_OR_STRING); + PARSER.declareField(Builder::setModelMemoryLimit, + (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), MODEL_MEMORY_LIMIT.getPreferredName()), MODEL_MEMORY_LIMIT, VALUE); + } + + private static DataFrameAnalysis parseAnalysis(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + DataFrameAnalysis analysis = parser.namedObject(DataFrameAnalysis.class, parser.currentName(), true); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return analysis; + } + + private final String id; + private final DataFrameAnalyticsSource source; + private final DataFrameAnalyticsDest dest; + private final DataFrameAnalysis analysis; + private final FetchSourceContext analyzedFields; + private final ByteSizeValue modelMemoryLimit; + + private DataFrameAnalyticsConfig(String id, DataFrameAnalyticsSource source, DataFrameAnalyticsDest dest, DataFrameAnalysis analysis, + @Nullable FetchSourceContext analyzedFields, @Nullable ByteSizeValue modelMemoryLimit) { + this.id = Objects.requireNonNull(id); + this.source = Objects.requireNonNull(source); + this.dest = Objects.requireNonNull(dest); + this.analysis = Objects.requireNonNull(analysis); + this.analyzedFields = analyzedFields; + this.modelMemoryLimit = modelMemoryLimit; + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsSource getSource() { + return source; + } + + public DataFrameAnalyticsDest getDest() { + return dest; 
+ } + + public DataFrameAnalysis getAnalysis() { + return analysis; + } + + public FetchSourceContext getAnalyzedFields() { + return analyzedFields; + } + + public ByteSizeValue getModelMemoryLimit() { + return modelMemoryLimit; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ID.getPreferredName(), id); + builder.field(SOURCE.getPreferredName(), source); + builder.field(DEST.getPreferredName(), dest); + builder.startObject(ANALYSIS.getPreferredName()); + builder.field(analysis.getName(), analysis); + builder.endObject(); + if (analyzedFields != null) { + builder.field(ANALYZED_FIELDS.getPreferredName(), analyzedFields); + } + if (modelMemoryLimit != null) { + builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), modelMemoryLimit.getStringRep()); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsConfig other = (DataFrameAnalyticsConfig) o; + return Objects.equals(id, other.id) + && Objects.equals(source, other.source) + && Objects.equals(dest, other.dest) + && Objects.equals(analysis, other.analysis) + && Objects.equals(analyzedFields, other.analyzedFields) + && Objects.equals(modelMemoryLimit, other.modelMemoryLimit); + } + + @Override + public int hashCode() { + return Objects.hash(id, source, dest, analysis, analyzedFields, getModelMemoryLimit()); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add("id", id) + .add("source", source) + .add("dest", dest) + .add("analysis", analysis) + .add("analyzedFields", analyzedFields) + .add("modelMemoryLimit", modelMemoryLimit) + .toString(); + } + + public static class Builder { + + private String id; + private DataFrameAnalyticsSource source; + private DataFrameAnalyticsDest dest; + private DataFrameAnalysis 
analysis; + private FetchSourceContext analyzedFields; + private ByteSizeValue modelMemoryLimit; + + private Builder() {} + + private Builder(String id) { + setId(id); + } + + public Builder(DataFrameAnalyticsConfig config) { + this.id = config.id; + this.source = new DataFrameAnalyticsSource(config.source); + this.dest = new DataFrameAnalyticsDest(config.dest); + this.analysis = config.analysis; + if (config.analyzedFields != null) { + this.analyzedFields = new FetchSourceContext(true, config.analyzedFields.includes(), config.analyzedFields.excludes()); + } + this.modelMemoryLimit = config.modelMemoryLimit; + } + + public Builder setId(String id) { + this.id = Objects.requireNonNull(id); + return this; + } + + public Builder setSource(DataFrameAnalyticsSource source) { + this.source = Objects.requireNonNull(source); + return this; + } + + public Builder setDest(DataFrameAnalyticsDest dest) { + this.dest = Objects.requireNonNull(dest); + return this; + } + + public Builder setAnalysis(DataFrameAnalysis analysis) { + this.analysis = Objects.requireNonNull(analysis); + return this; + } + + public Builder setAnalyzedFields(FetchSourceContext fields) { + this.analyzedFields = fields; + return this; + } + + public Builder setModelMemoryLimit(ByteSizeValue modelMemoryLimit) { + this.modelMemoryLimit = modelMemoryLimit; + return this; + } + + public DataFrameAnalyticsConfig build() { + return new DataFrameAnalyticsConfig(id, source, dest, analysis, analyzedFields, modelMemoryLimit); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java new file mode 100644 index 0000000000000..c15ca05c969b0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java @@ -0,0 +1,114 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class DataFrameAnalyticsDest implements ToXContentObject { + + public static DataFrameAnalyticsDest fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField RESULTS_FIELD = new ParseField("results_field"); + + private static ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("data_frame_analytics_dest", true, + (args) -> { + String index = (String) args[0]; + String resultsField = (String) args[1]; + return new DataFrameAnalyticsDest(index, resultsField); + }); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX); 
+ PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD); + } + + private final String index; + private final String resultsField; + + public DataFrameAnalyticsDest(String index) { + this(index, null); + } + + public DataFrameAnalyticsDest(String index, @Nullable String resultsField) { + this.index = requireNonNull(index); + this.resultsField = resultsField; + } + + public DataFrameAnalyticsDest(DataFrameAnalyticsDest other) { + this(other.index, other.resultsField); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + if (resultsField != null) { + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsDest other = (DataFrameAnalyticsDest) o; + return Objects.equals(index, other.index) + && Objects.equals(resultsField, other.resultsField); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add("index", index) + .add("resultsField", resultsField) + .toString(); + } + + @Override + public int hashCode() { + return Objects.hash(index, resultsField); + } + + public String getIndex() { + return index; + } + + public String getResultsField() { + return resultsField; + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java new file mode 100644 index 0000000000000..f18f65ad3f66c --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java @@ -0,0 +1,112 @@ +/* + * Licensed to Elasticsearch under one or more 
contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class DataFrameAnalyticsSource implements ToXContentObject { + + public static DataFrameAnalyticsSource fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField QUERY = new ParseField("query"); + + private static ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("data_frame_analytics_source", true, + (args) -> { + String index = (String) args[0]; + QueryConfig queryConfig = (QueryConfig) args[1]; + return new DataFrameAnalyticsSource(index, queryConfig); + }); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX); + 
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> QueryConfig.fromXContent(p), QUERY); + } + + private final String index; + private final QueryConfig queryConfig; + + public DataFrameAnalyticsSource(String index) { + this(index, null); + } + + public DataFrameAnalyticsSource(String index, @Nullable QueryConfig queryConfig) { + this.index = Objects.requireNonNull(index); + this.queryConfig = queryConfig; + } + + public DataFrameAnalyticsSource(DataFrameAnalyticsSource other) { + this(other.index, new QueryConfig(other.queryConfig)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + if (queryConfig != null) { + builder.field(QUERY.getPreferredName(), queryConfig.getQuery()); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsSource other = (DataFrameAnalyticsSource) o; + return Objects.equals(index, other.index) + && Objects.equals(queryConfig, other.queryConfig); + } + + @Override + public int hashCode() { + return Objects.hash(index, queryConfig); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add("index", index) + .add("queryConfig", queryConfig) + .toString(); + } + + public String getIndex() { + return index; + } + + public QueryConfig getQueryConfig() { + return queryConfig; + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java new file mode 100644 index 0000000000000..3b3a28eb3a8b5 --- /dev/null +++ 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.ArrayList; +import java.util.List; + +public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + List namedXContent = new ArrayList<>(); + + namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> { + return OutlierDetection.fromXContent(p); + })); + + return namedXContent; + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java new file mode 100644 index 0000000000000..64c6a098f38b1 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java @@ -0,0 +1,153 @@ +/* + * Licensed to Elasticsearch 
under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +public class OutlierDetection implements DataFrameAnalysis { + + public static OutlierDetection fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + static final ParseField NAME = new ParseField("outlier_detection"); + static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); + static final ParseField METHOD = new ParseField("method"); + + private static ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME.getPreferredName(), true, + (args) -> { + Integer nNeighbors = (Integer) args[0]; + Method method = (Method) args[1]; + return new OutlierDetection(nNeighbors, method); + }); + 
+ private final Integer nNeighbors; + private final Method method; + + static { + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), N_NEIGHBORS); + PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return Method.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, METHOD, ObjectParser.ValueType.STRING); + } + + /** + * Constructs the outlier detection configuration + * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. + * @param method The method. Leave unspecified for a dynamic mixture of methods. + */ + public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method) { + if (nNeighbors != null && nNeighbors <= 0) { + throw new IllegalArgumentException("[" + N_NEIGHBORS.getPreferredName() + "] must be a positive integer"); + } + + this.nNeighbors = nNeighbors; + this.method = method; + } + + /** + * Constructs the default outlier detection configuration + */ + public OutlierDetection() { + this(null, null); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (nNeighbors != null) { + builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + builder.field(METHOD.getPreferredName(), method); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + OutlierDetection other = (OutlierDetection) o; + return Objects.equals(nNeighbors, other.nNeighbors) + && Objects.equals(method, other.method); + } + + @Override + public int hashCode() { + return Objects.hash(nNeighbors, method); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add("nNeighbors", 
nNeighbors) + .add("method", method) + .toString(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public Map getParams() { + Map params = new HashMap<>(); + if (nNeighbors != null) { + params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + params.put(METHOD.getPreferredName(), method); + } + return params; + } + + public enum Method { + LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; + + public static Method fromString(String value) { + return Method.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java new file mode 100644 index 0000000000000..f3694dd50cb54 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java @@ -0,0 +1,88 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.AbstractQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; + +import java.io.IOException; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +/** + * Object for encapsulating the desired Query for a DataFrameAnalysis + */ +public class QueryConfig implements ToXContentObject { + + public static QueryConfig fromXContent(XContentParser parser) throws IOException { + QueryBuilder query = AbstractQueryBuilder.parseInnerQueryBuilder(parser); + return new QueryConfig(query); + } + + private final QueryBuilder query; + + public QueryConfig(QueryBuilder query) { + this.query = requireNonNull(query); + } + + public QueryConfig(QueryConfig queryConfig) { + this(requireNonNull(queryConfig).query); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + query.toXContent(builder, params); + return builder; + } + + public QueryBuilder getQuery() { + return query; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + QueryConfig other = (QueryConfig) o; + return Objects.equals(query, other.query); + } + + @Override + public int hashCode() { + return Objects.hash(query); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add("query", query) + .toString(); + } + + public boolean isValid() { + return query != null; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java new file mode 100644 index 0000000000000..4eba642401054 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -0,0 +1,88 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; + +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSourceTests.randomSourceConfig; +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDestTests.randomDestConfig; +import static org.elasticsearch.client.ml.dataframe.OutlierDetectionTests.randomOutlierDetection; + +public class DataFrameAnalyticsConfigTests extends AbstractXContentTestCase { + + public static DataFrameAnalyticsConfig randomDataFrameAnalyticsConfig() { + DataFrameAnalyticsConfig.Builder builder = + DataFrameAnalyticsConfig.builder(randomAlphaOfLengthBetween(1, 10)) + .setSource(randomSourceConfig()) + .setDest(randomDestConfig()) + .setAnalysis(randomOutlierDetection()); + if (randomBoolean()) { + builder.setAnalyzedFields(new FetchSourceContext(true, + generateRandomStringArray(10, 10, false, false), + generateRandomStringArray(10, 10, false, false))); + } + if (randomBoolean()) { + builder.setModelMemoryLimit(new ByteSizeValue(randomIntBetween(1, 16), randomFrom(ByteSizeUnit.MB, ByteSizeUnit.GB))); + } + return builder.build(); + } + + @Override + protected DataFrameAnalyticsConfig createTestInstance() { + return randomDataFrameAnalyticsConfig(); + } + + @Override + protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException { + return 
DataFrameAnalyticsConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java new file mode 100644 index 0000000000000..8e208cfbc7f99 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class DataFrameAnalyticsDestTests extends AbstractXContentTestCase { + + public static DataFrameAnalyticsDest randomDestConfig() { + return new DataFrameAnalyticsDest(randomAlphaOfLengthBetween(1, 10), randomAlphaOfLengthBetween(1, 10)); + } + + @Override + protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsDest.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected DataFrameAnalyticsDest createTestInstance() { + return randomDestConfig(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java new file mode 100644 index 0000000000000..0898afb5b7781 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java @@ -0,0 +1,69 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; + +import static java.util.Collections.emptyList; +import static org.elasticsearch.client.ml.dataframe.QueryConfigTests.randomQueryConfig; + + +public class DataFrameAnalyticsSourceTests extends AbstractXContentTestCase { + + public static DataFrameAnalyticsSource randomSourceConfig() { + return new DataFrameAnalyticsSource( + randomAlphaOfLengthBetween(1, 10), + randomBoolean() ? null : randomQueryConfig()); + } + + @Override + protected DataFrameAnalyticsSource doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsSource.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only as QueryConfig stores a Map + return field -> !field.isEmpty(); + } + + @Override + protected DataFrameAnalyticsSource createTestInstance() { + return randomSourceConfig(); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java new file mode 100644 index 0000000000000..96a8f7126b08b --- 
/dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java @@ -0,0 +1,68 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class OutlierDetectionTests extends AbstractXContentTestCase { + + public static OutlierDetection randomOutlierDetection() { + return new OutlierDetection( + randomBoolean() ? null : randomIntBetween(1, 20), + randomBoolean() ? 
null : randomFrom(OutlierDetection.Method.values())); + } + + @Override + protected OutlierDetection doParseInstance(XContentParser parser) throws IOException { + return OutlierDetection.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected OutlierDetection createTestInstance() { + return randomOutlierDetection(); + } + + public void testGetParams_GivenDefaults() { + OutlierDetection outlierDetection = new OutlierDetection(); + assertThat(outlierDetection.getParams().isEmpty(), is(true)); + } + + public void testGetParams_GivenExplicitValues() { + OutlierDetection outlierDetection = new OutlierDetection(42, OutlierDetection.Method.LDOF); + + Map params = outlierDetection.getParams(); + + assertThat(params.size(), equalTo(2)); + assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42)); + assertThat(params.get(OutlierDetection.METHOD.getPreferredName()), equalTo(OutlierDetection.Method.LDOF)); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java new file mode 100644 index 0000000000000..1e66445100b3e --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java @@ -0,0 +1,62 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +import static java.util.Collections.emptyList; + +public class QueryConfigTests extends AbstractXContentTestCase { + + public static QueryConfig randomQueryConfig() { + QueryBuilder queryBuilder = randomBoolean() ? 
new MatchAllQueryBuilder() : new MatchNoneQueryBuilder(); + return new QueryConfig(queryBuilder); + } + + @Override + protected QueryConfig createTestInstance() { + return randomQueryConfig(); + } + + @Override + protected QueryConfig doParseInstance(XContentParser parser) throws IOException { + return QueryConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } +} From 2c40ecf56ddf2bd289b175640e2d72ed530ab8db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 26 Apr 2019 12:54:18 +0200 Subject: [PATCH 39/67] [FEATURE][ML] Client-side API methods for DataFrame analytics config. (#41530) Implementation of Get/Put/Delete API methods on client side together with integration tests. relates to #elastic/ml-team/issues/102 --- .../client/MLRequestConverters.java | 46 +++++ .../client/MachineLearningClient.java | 129 ++++++++++++++ .../client/ValidationException.java | 7 + .../ml/DeleteDataFrameAnalyticsRequest.java | 64 +++++++ .../ml/GetDataFrameAnalyticsRequest.java | 83 +++++++++ .../ml/GetDataFrameAnalyticsResponse.java | 74 ++++++++ .../ml/PutDataFrameAnalyticsRequest.java | 70 ++++++++ .../ml/PutDataFrameAnalyticsResponse.java | 57 ++++++ .../client/ml/dataframe/OutlierDetection.java | 2 +- .../client/ml/job/util/PageParams.java | 4 +- ...icsearch.plugins.spi.NamedXContentProvider | 3 +- .../client/MLRequestConvertersTests.java | 53 +++++- .../client/MachineLearningIT.java | 164 ++++++++++++++++++ .../client/MlTestStateCleaner.java | 13 ++ .../client/RestHighLevelClientTests.java | 8 +- .../DeleteDataFrameAnalyticsRequestTests.java | 39 +++++ .../ml/GetDataFrameAnalyticsRequestTests.java | 39 +++++ .../ml/PutDataFrameAnalyticsRequestTests.java | 74 ++++++++ 18 
files changed, 922 insertions(+), 7 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 073b92f84a3a3..8350b5e496787 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -31,6 +31,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteFilterRequest; @@ -44,6 +45,7 @@ import org.elasticsearch.client.ml.GetCalendarEventsRequest; import 
org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -60,6 +62,7 @@ import org.elasticsearch.client.ml.PreviewDatafeedRequest; import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; @@ -576,6 +579,49 @@ static Request deleteCalendarEvent(DeleteCalendarEventRequest deleteCalendarEven return new Request(HttpDelete.METHOD_NAME, endpoint); } + static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) throws IOException { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml") + .addPathPartAsIs("data_frame") + .addPathPartAsIs("analytics") + .addPathPart(putRequest.getConfig().getId()) + .build(); + Request request = new Request(HttpPut.METHOD_NAME, endpoint); + request.setEntity(createEntity(putRequest, REQUEST_BODY_CONTENT_TYPE)); + return request; + } + + static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml") + .addPathPartAsIs("data_frame") + .addPathPartAsIs("analytics") + .addPathPart(Strings.collectionToCommaDelimitedString(getRequest.getIds())) + .build(); + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(request); + if (getRequest.getPageParams() != null) { + PageParams pageParams = getRequest.getPageParams(); + if (pageParams.getFrom() != null) { + 
params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != null) { + params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); + } + } + return request; + } + + static Request deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest deleteRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml") + .addPathPartAsIs("data_frame") + .addPathPartAsIs("analytics") + .addPathPart(deleteRequest.getId()) + .build(); + return new Request(HttpDelete.METHOD_NAME, endpoint); + } + static Request putFilter(PutFilterRequest putFilterRequest) throws IOException { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 2e359931c1025..26e0505e6261c 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -25,6 +25,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -47,6 +48,8 @@ import org.elasticsearch.client.ml.GetCalendarsResponse; import org.elasticsearch.client.ml.GetCategoriesRequest; import org.elasticsearch.client.ml.GetCategoriesResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import 
org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -78,6 +81,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -1877,4 +1882,128 @@ public void setUpgradeModeAsync(SetUpgradeModeRequest request, RequestOptions op listener, Collections.emptySet()); } + + /** + * Creates a new Data Frame Analytics config + *

+ * For additional info + * see PUT Data Frame Analytics documentation + * + * @param request The {@link PutDataFrameAnalyticsRequest} containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return The {@link PutDataFrameAnalyticsResponse} containing the created + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public PutDataFrameAnalyticsResponse putDataFrameAnalytics(PutDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::putDataFrameAnalytics, + options, + PutDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Creates a new Data Frame Analytics config asynchronously and notifies listener upon completion + *

+ * For additional info + * see PUT Data Frame Analytics documentation + * + * @param request The {@link PutDataFrameAnalyticsRequest} containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void putDataFrameAnalyticsAsync(PutDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::putDataFrameAnalytics, + options, + PutDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Gets a single or multiple Data Frame Analytics configs + *

+ * For additional info + * see GET Data Frame Analytics documentation + * + * @param request The {@link GetDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetDataFrameAnalyticsResponse} response object containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} objects + */ + public GetDataFrameAnalyticsResponse getDataFrameAnalytics(GetDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getDataFrameAnalytics, + options, + GetDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets a single or multiple Data Frame Analytics configs asynchronously and notifies listener upon completion + *

+ * For additional info + * see GET Data Frame Analytics documentation + * + * @param request The {@link GetDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void getDataFrameAnalyticsAsync(GetDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getDataFrameAnalytics, + options, + GetDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + + /** + * Deletes the given Data Frame Analytics config + *

+ * For additional info + * see DELETE Data Frame Analytics documentation + * + * @param request The {@link DeleteDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return action acknowledgement + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public AcknowledgedResponse deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::deleteDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Deletes the given Data Frame Analytics config asynchronously and notifies listener upon completion + *

+ * For additional info + * see DELETE Data Frame Analytics documentation + * + * @param request The {@link DeleteDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void deleteDataFrameAnalyticsAsync(DeleteDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::deleteDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + listener, + Collections.emptySet()); + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ValidationException.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ValidationException.java index 730ea7e95de12..c2988f3f7f416 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ValidationException.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ValidationException.java @@ -27,6 +27,13 @@ * Encapsulates an accumulation of validation errors */ public class ValidationException extends IllegalArgumentException { + + public static ValidationException withError(String error) { + ValidationException e = new ValidationException(); + e.addValidationError(error); + return e; + } + private final List validationErrors = new ArrayList<>(); /** diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..f03466632304d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; + +import java.util.Objects; +import java.util.Optional; + +/** + * Request to delete a data frame analytics config + */ +public class DeleteDataFrameAnalyticsRequest implements Validatable { + + private final String id; + + public DeleteDataFrameAnalyticsRequest(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + @Override + public Optional validate() { + if (id == null) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DeleteDataFrameAnalyticsRequest other = (DeleteDataFrameAnalyticsRequest) o; + return Objects.equals(id, other.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java new file mode 100644 index 
0000000000000..9da8d5bbce7d6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java @@ -0,0 +1,83 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.job.util.PageParams; +import org.elasticsearch.common.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +public class GetDataFrameAnalyticsRequest implements Validatable { + + private final List ids; + private PageParams pageParams; + + /** + * Helper method to create a request that will get ALL Data Frame Analytics + * @return new {@link GetDataFrameAnalyticsRequest} object for the id "_all" + */ + public static GetDataFrameAnalyticsRequest getAllDataFrameAnalyticsRequest() { + return new GetDataFrameAnalyticsRequest("_all"); + } + + public GetDataFrameAnalyticsRequest(String... 
ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public PageParams getPageParams() { + return pageParams; + } + + public void setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsRequest other = (GetDataFrameAnalyticsRequest) o; + return Objects.equals(ids, other.ids) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, pageParams); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java new file mode 100644 index 0000000000000..76996e9d4d0b6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class GetDataFrameAnalyticsResponse { + + public static final ParseField DATA_FRAME_ANALYTICS = new ParseField("data_frame_analytics"); + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "get_data_frame_analytics", + true, + args -> new GetDataFrameAnalyticsResponse((List) args[0])); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> DataFrameAnalyticsConfig.fromXContent(p), DATA_FRAME_ANALYTICS); + } + + public static GetDataFrameAnalyticsResponse fromXContent(final XContentParser parser) { + return PARSER.apply(parser, null); + } + + private List analytics; + + public GetDataFrameAnalyticsResponse(List analytics) { + this.analytics = analytics; + } + + public List getAnalytics() { + return analytics; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsResponse other = (GetDataFrameAnalyticsResponse) o; + return Objects.equals(this.analytics, other.analytics); + } + + @Override + public int hashCode() { + return Objects.hash(analytics); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..14950a74c9187 --- /dev/null 
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java @@ -0,0 +1,70 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; +import java.util.Optional; + +public class PutDataFrameAnalyticsRequest implements ToXContentObject, Validatable { + + private final DataFrameAnalyticsConfig config; + + public PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfig config) { + this.config = config; + } + + public DataFrameAnalyticsConfig getConfig() { + return config; + } + + @Override + public Optional validate() { + if (config == null) { + return Optional.of(ValidationException.withError("put requires a non-null data frame analytics config")); + } + return Optional.empty(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return 
config.toXContent(builder, params); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + PutDataFrameAnalyticsRequest other = (PutDataFrameAnalyticsRequest) o; + return Objects.equals(config, other.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java new file mode 100644 index 0000000000000..e6c4be15987d4 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class PutDataFrameAnalyticsResponse { + + public static PutDataFrameAnalyticsResponse fromXContent(XContentParser parser) throws IOException { + return new PutDataFrameAnalyticsResponse(DataFrameAnalyticsConfig.fromXContent(parser)); + } + + private final DataFrameAnalyticsConfig config; + + public PutDataFrameAnalyticsResponse(DataFrameAnalyticsConfig config) { + this.config = config; + } + + public DataFrameAnalyticsConfig getConfig() { + return config; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + PutDataFrameAnalyticsResponse other = (PutDataFrameAnalyticsResponse) o; + return Objects.equals(config, other.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java index 64c6a098f38b1..be5334c14d518 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java @@ -39,7 +39,7 @@ public static OutlierDetection fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - static final ParseField NAME = new ParseField("outlier_detection"); + public static final ParseField NAME = new ParseField("outlier_detection"); static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); static final ParseField METHOD = new ParseField("method"); diff --git 
a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/util/PageParams.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/util/PageParams.java index 52d54188f7007..b556fd3ce0ad3 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/util/PageParams.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/util/PageParams.java @@ -57,11 +57,11 @@ public PageParams(@Nullable Integer from, @Nullable Integer size) { this.size = size; } - public int getFrom() { + public Integer getFrom() { return from; } - public int getSize() { + public Integer getSize() { return size; } diff --git a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider index 4204a868246a5..342c606a540a6 100644 --- a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider +++ b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider @@ -1 +1,2 @@ -org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider \ No newline at end of file +org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider +org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider \ No newline at end of file diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 11faaf879729d..6bb16b8c1c3c2 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -27,6 +27,7 @@ import 
org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteFilterRequest; @@ -41,6 +42,7 @@ import org.elasticsearch.client.ml.GetCalendarEventsRequest; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -57,6 +59,7 @@ import org.elasticsearch.client.ml.PreviewDatafeedRequest; import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; @@ -74,6 +77,8 @@ import org.elasticsearch.client.ml.calendars.ScheduledEventTests; import org.elasticsearch.client.ml.datafeed.DatafeedConfig; import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.Detector; @@ -84,23 +89,30 @@ import org.elasticsearch.client.ml.job.config.MlFilterTests; import org.elasticsearch.client.ml.job.util.PageParams; import org.elasticsearch.common.Strings; 
+import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.ESTestCase; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfigTests.randomDataFrameAnalyticsConfig; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; import static org.hamcrest.core.IsNull.nullValue; @@ -154,7 +166,6 @@ public void testGetJobStats() { assertEquals(Boolean.toString(true), request.getParameters().get("allow_no_jobs")); } - public void testOpenJob() throws Exception { String jobId = "some-job-id"; OpenJobRequest openJobRequest = new OpenJobRequest(jobId); @@ -669,6 +680,38 @@ public void testDeleteCalendarEvent() { assertEquals("/_ml/calendars/" + calendarId + "/events/" + eventId, request.getEndpoint()); } + public void testPutDataFrameAnalytics() throws IOException { + PutDataFrameAnalyticsRequest putRequest = new PutDataFrameAnalyticsRequest(randomDataFrameAnalyticsConfig()); + Request request = MLRequestConverters.putDataFrameAnalytics(putRequest); + assertEquals(HttpPut.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + putRequest.getConfig().getId(), request.getEndpoint()); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, 
request.getEntity().getContent())) { + DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.fromXContent(parser); + assertThat(parsedConfig, equalTo(putRequest.getConfig())); + } + } + + public void testGetDataFrameAnalytics() { + String configId1 = randomAlphaOfLength(10); + String configId2 = randomAlphaOfLength(10); + String configId3 = randomAlphaOfLength(10); + GetDataFrameAnalyticsRequest getRequest = new GetDataFrameAnalyticsRequest(configId1, configId2, configId3); + getRequest.setPageParams(new PageParams(100, 300)); + + Request request = MLRequestConverters.getDataFrameAnalytics(getRequest); + assertEquals(HttpGet.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3, request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"))); + assertNull(request.getEntity()); + } + + public void testDeleteDataFrameAnalytics() { + DeleteDataFrameAnalyticsRequest deleteRequest = new DeleteDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.deleteDataFrameAnalytics(deleteRequest); + assertEquals(HttpDelete.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + deleteRequest.getId(), request.getEndpoint()); + } + public void testPutFilter() throws IOException { MlFilter filter = MlFilterTests.createRandomBuilder("foo").build(); PutFilterRequest putFilterRequest = new PutFilterRequest(filter); @@ -835,6 +878,14 @@ public void testSetUpgradeMode() { assertThat(request.getParameters().get(SetUpgradeModeRequest.TIMEOUT.getPreferredName()), is("1h")); } + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + 
return new NamedXContentRegistry(namedXContent); + } + private static Job createValidJob(String jobId) { AnalysisConfig.Builder analysisConfig = AnalysisConfig.builder(Collections.singletonList( Detector.builder().setFunction("count").build())); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index f7b7b148f660b..a5b8e33290fd4 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -33,6 +33,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -51,6 +52,8 @@ import org.elasticsearch.client.ml.GetCalendarEventsResponse; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCalendarsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -76,6 +79,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import 
org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -102,6 +107,11 @@ import org.elasticsearch.client.ml.datafeed.DatafeedState; import org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.QueryConfig; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -113,9 +123,11 @@ import org.elasticsearch.client.ml.job.process.ModelSnapshot; import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.client.ml.job.util.PageParams; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.junit.After; @@ -1230,6 +1242,158 @@ public void testDeleteCalendarEvent() throws IOException { assertThat(remainingIds, not(hasItem(deletedEvent))); } + public void testPutDataFrameAnalyticsConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(new DataFrameAnalyticsSource("test-source-index")) + .setDest(new DataFrameAnalyticsDest("test-dest-index")) + .setAnalysis(new OutlierDetection()) + .build(); + + 
PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + assertThat(createdConfig.getId(), equalTo(config.getId())); + assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex())); + assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value + assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex())); + assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value + assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis())); + assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields())); + assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value + } + + public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(new DataFrameAnalyticsSource("test-source-index")) + .setDest(new DataFrameAnalyticsDest("test-dest-index")) + .setAnalysis(new OutlierDetection()) + .build(); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId), + machineLearningClient::getDataFrameAnalytics, 
machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(1)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), contains(createdConfig)); + } + + public void testGetDataFrameAnalyticsConfig_MultipleConfigs() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configIdPrefix = "test-config-"; + int numberOfConfigs = 10; + List createdConfigs = new ArrayList<>(); + for (int i = 0; i < numberOfConfigs; ++i) { + String configId = configIdPrefix + i; + DataFrameAnalyticsConfig config = + DataFrameAnalyticsConfig.builder(configId) + .setSource(new DataFrameAnalyticsSource("index-source-test")) + .setDest(new DataFrameAnalyticsDest("index-dest-test")) + .setAnalysis(new OutlierDetection()) + .build(); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + createdConfigs.add(createdConfig); + } + + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(numberOfConfigs)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), containsInAnyOrder(createdConfigs.toArray())); + } + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configIdPrefix + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(numberOfConfigs)); + 
assertThat(getDataFrameAnalyticsResponse.getAnalytics(), containsInAnyOrder(createdConfigs.toArray())); + } + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configIdPrefix + "9", configIdPrefix + "1", configIdPrefix + "4"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(3)); + assertThat( + getDataFrameAnalyticsResponse.getAnalytics(), + containsInAnyOrder(createdConfigs.get(1), createdConfigs.get(4), createdConfigs.get(9))); + } + { + GetDataFrameAnalyticsRequest getDataFrameAnalyticsRequest = new GetDataFrameAnalyticsRequest(configIdPrefix + "*"); + getDataFrameAnalyticsRequest.setPageParams(new PageParams(3, 4)); + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + getDataFrameAnalyticsRequest, + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(4)); + assertThat( + getDataFrameAnalyticsResponse.getAnalytics(), + containsInAnyOrder(createdConfigs.get(3), createdConfigs.get(4), createdConfigs.get(5), createdConfigs.get(6))); + } + } + + public void testGetDataFrameAnalyticsConfig_ConfigNotFound() { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("config_that_does_not_exist"); + ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, + () -> execute(request, machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync)); + assertThat(exception.status().getStatus(), equalTo(404)); + } + + public void testDeleteDataFrameAnalyticsConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "test-config"; + 
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(new DataFrameAnalyticsSource("test-source-index")) + .setDest(new DataFrameAnalyticsDest("test-dest-index")) + .setAnalysis(new OutlierDetection()) + .build(); + + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(0)); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + + getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(1)); + + AcknowledgedResponse deleteDataFrameAnalyticsResponse = execute( + new DeleteDataFrameAnalyticsRequest(configId), + machineLearningClient::deleteDataFrameAnalytics, machineLearningClient::deleteDataFrameAnalyticsAsync); + assertTrue(deleteDataFrameAnalyticsResponse.isAcknowledged()); + + getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(0)); + } + + public void testDeleteDataFrameAnalyticsConfig_ConfigNotFound() { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("config_that_does_not_exist"); + ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, + () -> execute( + request, machineLearningClient::deleteDataFrameAnalytics, 
machineLearningClient::deleteDataFrameAnalyticsAsync)); + assertThat(exception.status().getStatus(), equalTo(404)); + } + public void testPutFilter() throws Exception { String filterId = "filter-job-test"; MlFilter mlFilter = MlFilter.builder(filterId) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java index c565af7c37202..f5776e99fd0eb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java @@ -20,14 +20,18 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.client.ml.CloseJobRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteJobRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetJobRequest; import org.elasticsearch.client.ml.GetJobResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.datafeed.DatafeedConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.client.ml.job.config.Job; import java.io.IOException; @@ -48,6 +52,7 @@ public MlTestStateCleaner(Logger logger, MachineLearningClient mlClient) { public void clearMlMetadata() throws IOException { deleteAllDatafeeds(); deleteAllJobs(); + deleteAllDataFrameAnalytics(); } private void deleteAllDatafeeds() throws IOException { @@ -99,4 +104,12 @@ private void closeAllJobs() { throw new RuntimeException("Had to resort to force-closing jobs, something went wrong?", e1); } } + + 
private void deleteAllDataFrameAnalytics() throws IOException { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = + mlClient.getDataFrameAnalytics(GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), RequestOptions.DEFAULT); + for (DataFrameAnalyticsConfig config : getDataFrameAnalyticsResponse.getAnalytics()) { + mlClient.deleteDataFrameAnalytics(new DeleteDataFrameAnalyticsRequest(config.getId()), RequestOptions.DEFAULT); + } + } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index ed5d7b66d80c1..b9dfa4274c28d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -56,6 +56,8 @@ import org.elasticsearch.client.indexlifecycle.SetPriorityAction; import org.elasticsearch.client.indexlifecycle.ShrinkAction; import org.elasticsearch.client.indexlifecycle.UnfollowAction; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -664,7 +666,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(20, namedXContents.size()); + assertEquals(21, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -674,7 +676,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 4, categories.size()); + assertEquals("Had: " 
+ categories, 5, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -698,6 +700,8 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(ShrinkAction.NAME)); assertTrue(names.contains(FreezeAction.NAME)); assertTrue(names.contains(SetPriorityAction.NAME)); + assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class)); + assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..bc2ca2d954e76 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class DeleteDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new DeleteDataFrameAnalyticsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new DeleteDataFrameAnalyticsRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new DeleteDataFrameAnalyticsRequest(null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..56d87ea6bef49 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class GetDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetDataFrameAnalyticsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetDataFrameAnalyticsRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new GetDataFrameAnalyticsRequest(new String[0]).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..19bc68fa36118 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class PutDataFrameAnalyticsRequestTests extends AbstractXContentTestCase { + + public void testValidate_Ok() { + assertFalse(createTestInstance().validate().isPresent()); + } + + public void testValidate_Failure() { + Optional exception = new PutDataFrameAnalyticsRequest(null).validate(); + assertTrue(exception.isPresent()); + assertThat(exception.get().getMessage(), containsString("put requires a non-null data frame analytics config")); + } + + @Override + protected PutDataFrameAnalyticsRequest createTestInstance() { + return new PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfigTests.randomDataFrameAnalyticsConfig()); + } + + @Override + protected PutDataFrameAnalyticsRequest doParseInstance(XContentParser parser) throws IOException { + return new PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfig.fromXContent(parser)); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new 
MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } +} From efd15567361fc66c0575097c635d3f000c3e8fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Tue, 30 Apr 2019 05:56:21 +0200 Subject: [PATCH 40/67] Implement GetDataFrameAnalyticsStats API method (#41637) Implement GetDataFrameAnalyticsStats API method --- .../client/MLRequestConverters.java | 33 +++-- .../client/MachineLearningClient.java | 40 ++++++ .../ml/GetDataFrameAnalyticsRequest.java | 3 +- .../ml/GetDataFrameAnalyticsStatsRequest.java | 79 +++++++++++ .../GetDataFrameAnalyticsStatsResponse.java | 102 +++++++++++++ .../client/ml/NodeAttributes.java | 6 + .../ml/dataframe/DataFrameAnalyticsState.java | 34 +++++ .../ml/dataframe/DataFrameAnalyticsStats.java | 134 ++++++++++++++++++ .../client/MLRequestConvertersTests.java | 20 ++- .../client/MachineLearningIT.java | 124 ++++++++-------- ...etDataFrameAnalyticsStatsRequestTests.java | 39 +++++ .../DataFrameAnalyticsStatsTests.java | 66 +++++++++ 12 files changed, 604 insertions(+), 76 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 8350b5e496787..b869508aaf579 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -46,6 +46,7 @@ import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -581,9 +582,7 @@ static Request deleteCalendarEvent(DeleteCalendarEventRequest deleteCalendarEven static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) throws IOException { String endpoint = new EndpointBuilder() - .addPathPartAsIs("_ml") - .addPathPartAsIs("data_frame") - .addPathPartAsIs("analytics") + .addPathPartAsIs("_ml", "data_frame", "analytics") .addPathPart(putRequest.getConfig().getId()) .build(); Request request = new Request(HttpPut.METHOD_NAME, endpoint); @@ -593,9 +592,7 @@ static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) th static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { String endpoint = new EndpointBuilder() - .addPathPartAsIs("_ml") - .addPathPartAsIs("data_frame") - .addPathPartAsIs("analytics") + .addPathPartAsIs("_ml", "data_frame", "analytics") .addPathPart(Strings.collectionToCommaDelimitedString(getRequest.getIds())) .build(); Request request = new Request(HttpGet.METHOD_NAME, endpoint); @@ -612,11 +609,29 @@ static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { return request; } + static Request getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest getStatsRequest) { + String 
endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(Strings.collectionToCommaDelimitedString(getStatsRequest.getIds())) + .addPathPartAsIs("_stats") + .build(); + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(request); + if (getStatsRequest.getPageParams() != null) { + PageParams pageParams = getStatsRequest.getPageParams(); + if (pageParams.getFrom() != null) { + params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != null) { + params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); + } + } + return request; + } + static Request deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest deleteRequest) { String endpoint = new EndpointBuilder() - .addPathPartAsIs("_ml") - .addPathPartAsIs("data_frame") - .addPathPartAsIs("analytics") + .addPathPartAsIs("_ml", "data_frame", "analytics") .addPathPart(deleteRequest.getId()) .build(); return new Request(HttpDelete.METHOD_NAME, endpoint); diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 26e0505e6261c..61b53dac82eaf 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -50,6 +50,8 @@ import org.elasticsearch.client.ml.GetCategoriesResponse; import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import 
org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -1966,6 +1968,44 @@ public void getDataFrameAnalyticsAsync(GetDataFrameAnalyticsRequest request, Req Collections.emptySet()); } + /** + * Gets the running statistics of a Data Frame Analytics + *

+ * For additional info + * see GET Data Frame Analytics Stats documentation + * + * @param request The {@link GetDataFrameAnalyticsStatsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetDataFrameAnalyticsStatsResponse} response object + */ + public GetDataFrameAnalyticsStatsResponse getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getDataFrameAnalyticsStats, + options, + GetDataFrameAnalyticsStatsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets the running statistics of a Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see GET Data Frame Analytics Stats documentation + * + * @param request The {@link GetDataFrameAnalyticsStatsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void getDataFrameAnalyticsStatsAsync(GetDataFrameAnalyticsStatsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getDataFrameAnalyticsStats, + options, + GetDataFrameAnalyticsStatsResponse::fromXContent, + listener, + Collections.emptySet()); + } /** * Deletes the given Data Frame Analytics config diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java index 9da8d5bbce7d6..eea04fb43d547 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java @@ -54,8 +54,9 @@ public PageParams getPageParams() { return pageParams; } - public void setPageParams(@Nullable PageParams pageParams) { + public GetDataFrameAnalyticsRequest setPageParams(@Nullable PageParams pageParams) { this.pageParams = pageParams; + return this; } @Override diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java new file mode 100644 index 0000000000000..044d62f229fe5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java @@ -0,0 +1,79 @@ +/* + * Licensed to Elasticsearch under one or more contributor + 
* license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.job.util.PageParams; +import org.elasticsearch.common.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +/** + * Request to get data frame analytics stats + */ +public class GetDataFrameAnalyticsStatsRequest implements Validatable { + + private final List ids; + private PageParams pageParams; + + public GetDataFrameAnalyticsStatsRequest(String... 
ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public PageParams getPageParams() { + return pageParams; + } + + public GetDataFrameAnalyticsStatsRequest setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + return this; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsStatsRequest other = (GetDataFrameAnalyticsStatsRequest) o; + return Objects.equals(ids, other.ids) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, pageParams); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java new file mode 100644 index 0000000000000..5391a576e98b0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java @@ -0,0 +1,102 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.client.dataframe.AcknowledgedTasksResponse; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class GetDataFrameAnalyticsStatsResponse { + + public static GetDataFrameAnalyticsStatsResponse fromXContent(XContentParser parser) { + return GetDataFrameAnalyticsStatsResponse.PARSER.apply(parser, null); + } + + private static final ParseField DATA_FRAME_ANALYTICS = new ParseField("data_frame_analytics"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "get_data_frame_analytics_stats_response", true, + args -> new GetDataFrameAnalyticsStatsResponse( + (List) args[0], + (List) args[1], + (List) args[2])); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> DataFrameAnalyticsStats.fromXContent(p), DATA_FRAME_ANALYTICS); + PARSER.declareObjectArray( + 
optionalConstructorArg(), (p, c) -> TaskOperationFailure.fromXContent(p), AcknowledgedTasksResponse.TASK_FAILURES); + PARSER.declareObjectArray( + optionalConstructorArg(), (p, c) -> ElasticsearchException.fromXContent(p), AcknowledgedTasksResponse.NODE_FAILURES); + } + + private final List analyticsStats; + private final List taskFailures; + private final List nodeFailures; + + public GetDataFrameAnalyticsStatsResponse(List analyticsStats, + @Nullable List taskFailures, + @Nullable List nodeFailures) { + this.analyticsStats = analyticsStats; + this.taskFailures = taskFailures == null ? Collections.emptyList() : Collections.unmodifiableList(taskFailures); + this.nodeFailures = nodeFailures == null ? Collections.emptyList() : Collections.unmodifiableList(nodeFailures); + } + + public List getAnalyticsStats() { + return analyticsStats; + } + + public List getNodeFailures() { + return nodeFailures; + } + + public List getTaskFailures() { + return taskFailures; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsStatsResponse other = (GetDataFrameAnalyticsStatsResponse) o; + return Objects.equals(analyticsStats, other.analyticsStats) + && Objects.equals(nodeFailures, other.nodeFailures) + && Objects.equals(taskFailures, other.taskFailures); + } + + @Override + public int hashCode() { + return Objects.hash(analyticsStats, nodeFailures, taskFailures); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java index 892df340abd6b..a0f0d25f2ca01 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.ml; import 
org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -147,4 +148,9 @@ public boolean equals(Object other) { Objects.equals(transportAddress, that.transportAddress) && Objects.equals(attributes, that.attributes); } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java new file mode 100644 index 0000000000000..6ee349b8e8d38 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java @@ -0,0 +1,34 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import java.util.Locale; + +public enum DataFrameAnalyticsState { + STARTED, REINDEXING, ANALYZING, STOPPING, STOPPED; + + public static DataFrameAnalyticsState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public String value() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java new file mode 100644 index 0000000000000..ef9dced4263cc --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -0,0 +1,134 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.client.ml.NodeAttributes; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class DataFrameAnalyticsStats { + + public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + static final ParseField ID = new ParseField("id"); + static final ParseField STATE = new ParseField("state"); + static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent"); + static final ParseField NODE = new ParseField("node"); + static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation"); + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "data_frame_analytics_stats", true, + args -> new DataFrameAnalyticsStats( + (String) args[0], + (DataFrameAnalyticsState) args[1], + (Integer) args[2], + (NodeAttributes) args[3], + (String) args[4])); + + static { + PARSER.declareString(constructorArg(), ID); + PARSER.declareField(constructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return DataFrameAnalyticsState.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, STATE, ObjectParser.ValueType.STRING); + PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); + PARSER.declareObject(optionalConstructorArg(), 
NodeAttributes.PARSER, NODE); + PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION); + } + + private final String id; + private final DataFrameAnalyticsState state; + private final Integer progressPercent; + private final NodeAttributes node; + private final String assignmentExplanation; + + public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable Integer progressPercent, + @Nullable NodeAttributes node, @Nullable String assignmentExplanation) { + this.id = id; + this.state = state; + this.progressPercent = progressPercent; + this.node = node; + this.assignmentExplanation = assignmentExplanation; + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsState getState() { + return state; + } + + public Integer getProgressPercent() { + return progressPercent; + } + + public NodeAttributes getNode() { + return node; + } + + public String getAssignmentExplanation() { + return assignmentExplanation; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsStats other = (DataFrameAnalyticsStats) o; + return Objects.equals(id, other.id) + && Objects.equals(state, other.state) + && Objects.equals(progressPercent, other.progressPercent) + && Objects.equals(node, other.node) + && Objects.equals(assignmentExplanation, other.assignmentExplanation); + } + + @Override + public int hashCode() { + return Objects.hash(id, state, progressPercent, node, assignmentExplanation); + } + + @Override + public String toString() { + return new ToStringBuilder(getClass()) + .add("id", id) + .add("state", state) + .add("progressPercent", progressPercent) + .add("node", node) + .add("assignmentExplanation", assignmentExplanation) + .toString(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 6bb16b8c1c3c2..6d9046aa6288d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -43,6 +43,7 @@ import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -695,8 +696,8 @@ public void testGetDataFrameAnalytics() { String configId1 = randomAlphaOfLength(10); String configId2 = randomAlphaOfLength(10); String configId3 = randomAlphaOfLength(10); - GetDataFrameAnalyticsRequest getRequest = new GetDataFrameAnalyticsRequest(configId1, configId2, configId3); - getRequest.setPageParams(new PageParams(100, 300)); + GetDataFrameAnalyticsRequest getRequest = new GetDataFrameAnalyticsRequest(configId1, configId2, configId3) + .setPageParams(new PageParams(100, 300)); Request request = MLRequestConverters.getDataFrameAnalytics(getRequest); assertEquals(HttpGet.METHOD_NAME, request.getMethod()); @@ -705,11 +706,26 @@ public void testGetDataFrameAnalytics() { assertNull(request.getEntity()); } + public void testGetDataFrameAnalyticsStats() { + String configId1 = randomAlphaOfLength(10); + String configId2 = randomAlphaOfLength(10); + String configId3 = randomAlphaOfLength(10); + GetDataFrameAnalyticsStatsRequest getStatsRequest = new GetDataFrameAnalyticsStatsRequest(configId1, configId2, configId3) + .setPageParams(new PageParams(100, 300)); + + Request request = MLRequestConverters.getDataFrameAnalyticsStats(getStatsRequest); + assertEquals(HttpGet.METHOD_NAME, 
request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3 + "/_stats", request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"))); + assertNull(request.getEntity()); + } + public void testDeleteDataFrameAnalytics() { DeleteDataFrameAnalyticsRequest deleteRequest = new DeleteDataFrameAnalyticsRequest(randomAlphaOfLength(10)); Request request = MLRequestConverters.deleteDataFrameAnalytics(deleteRequest); assertEquals(HttpDelete.METHOD_NAME, request.getMethod()); assertEquals("/_ml/data_frame/analytics/" + deleteRequest.getId(), request.getEndpoint()); + assertNull(request.getEntity()); } public void testPutFilter() throws IOException { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index a5b8e33290fd4..597caf26ba27f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -54,6 +54,8 @@ import org.elasticsearch.client.ml.GetCalendarsResponse; import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -112,6 +114,8 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import 
org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -540,18 +544,7 @@ public void testStartDatafeed() throws Exception { String indexName = "start_data_1"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -623,18 +616,7 @@ public void testStopDatafeed() throws Exception { String indexName = "stop_data_1"; // Set up the index - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -696,18 +678,7 @@ public void testGetDatafeedStats() throws Exception { String indexName = "datafeed_stats_data_1"; // Set up the index - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", 
"date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -774,18 +745,7 @@ public void testPreviewDatafeed() throws Exception { String indexName = "preview_data_1"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -838,21 +798,9 @@ public void testDeleteExpiredDataGivenNothingToDelete() throws Exception { } private String createExpiredData(String jobId) throws Exception { - String indexId = jobId + "-data"; + String indexName = jobId + "-data"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexId); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .field("format", "epoch_millis") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -865,7 +813,7 @@ private String createExpiredData(String jobId) throws Exception { long timestamp = nowMillis - 
TimeValue.timeValueHours(totalBuckets - bucket).getMillis(); int bucketRate = bucket == anomalousBucket ? anomalousRate : normalRate; for (int point = 0; point < bucketRate; point++) { - IndexRequest indexRequest = new IndexRequest(indexId); + IndexRequest indexRequest = new IndexRequest(indexName); indexRequest.source(XContentType.JSON, "timestamp", timestamp, "total", randomInt(1000)); bulk.add(indexRequest); } @@ -884,7 +832,7 @@ private String createExpiredData(String jobId) throws Exception { Job job = buildJobForExpiredDataTests(jobId); putJob(job); openJob(job); - String datafeedId = createAndPutDatafeed(jobId, indexId); + String datafeedId = createAndPutDatafeed(jobId, indexName); startDatafeed(datafeedId, String.valueOf(0), String.valueOf(nowMillis - TimeValue.timeValueHours(24).getMillis())); @@ -1351,6 +1299,39 @@ public void testGetDataFrameAnalyticsConfig_ConfigNotFound() { assertThat(exception.status().getStatus(), equalTo(404)); } + public void testGetDataFrameAnalyticsStats() throws Exception { + String sourceIndex = "get-stats-test-source-index"; + String destIndex = "get-stats-test-dest-index"; + createIndex(sourceIndex); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "get-stats-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(new DataFrameAnalyticsSource(sourceIndex)) + .setDest(new DataFrameAnalyticsDest(destIndex)) + .setAnalysis(new OutlierDetection()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + + GetDataFrameAnalyticsStatsResponse statsResponse = execute( + new GetDataFrameAnalyticsStatsRequest(configId), + machineLearningClient::getDataFrameAnalyticsStats, 
machineLearningClient::getDataFrameAnalyticsStatsAsync); + + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + assertThat(stats.getId(), equalTo(configId)); + assertThat(stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED)); + assertNull(stats.getProgressPercent()); + assertNull(stats.getNode()); + assertNull(stats.getAssignmentExplanation()); + assertThat(statsResponse.getNodeFailures(), hasSize(0)); + assertThat(statsResponse.getTaskFailures(), hasSize(0)); + } + public void testDeleteDataFrameAnalyticsConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String configId = "test-config"; @@ -1394,6 +1375,21 @@ public void testDeleteDataFrameAnalyticsConfig_ConfigNotFound() { assertThat(exception.status().getStatus(), equalTo(404)); } + private void createIndex(String indexName) throws IOException { + CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); + createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("timestamp") + .field("type", "date") + .endObject() + .startObject("total") + .field("type", "long") + .endObject() + .endObject() + .endObject()); + highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + } + public void testPutFilter() throws Exception { String filterId = "filter-job-test"; MlFilter mlFilter = MlFilter.builder(filterId) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java new file mode 100644 index 0000000000000..4e08d99eaa659 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under 
one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class GetDataFrameAnalyticsStatsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetDataFrameAnalyticsStatsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetDataFrameAnalyticsStatsRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new GetDataFrameAnalyticsStatsRequest(new String[0]).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java new file mode 100644 index 0000000000000..ed6e24f754d19 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -0,0 +1,66 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.client.ml.NodeAttributesTests; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; + +public class DataFrameAnalyticsStatsTests extends ESTestCase { + + public void testFromXContent() throws IOException { + xContentTester(this::createParser, + DataFrameAnalyticsStatsTests::randomDataFrameAnalyticsStats, + DataFrameAnalyticsStatsTests::toXContent, + DataFrameAnalyticsStats::fromXContent) + .supportsUnknownFields(true) + .randomFieldsExcludeFilter(field -> field.startsWith("node.attributes")) + .test(); + } + + public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { + return new DataFrameAnalyticsStats( + randomAlphaOfLengthBetween(1, 10), + randomFrom(DataFrameAnalyticsState.values()), + randomBoolean() ? null : randomIntBetween(0, 100), + randomBoolean() ? null : NodeAttributesTests.createRandom(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 20)); + } + + public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder builder) throws IOException { + builder.startObject(); + builder.field(DataFrameAnalyticsStats.ID.getPreferredName(), stats.getId()); + builder.field(DataFrameAnalyticsStats.STATE.getPreferredName(), stats.getState().value()); + if (stats.getProgressPercent() != null) { + builder.field(DataFrameAnalyticsStats.PROGRESS_PERCENT.getPreferredName(), stats.getProgressPercent()); + } + if (stats.getNode() != null) { + builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode()); + } + if (stats.getAssignmentExplanation() != null) { + builder.field(DataFrameAnalyticsStats.ASSIGNMENT_EXPLANATION.getPreferredName(), stats.getAssignmentExplanation()); + } + builder.endObject(); + } +} From 75c7380bc8924404469f46f7b21bcab5d4927f5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Tue, 30 Apr 2019 08:32:12 +0200 Subject: [PATCH 41/67] [FEATURE][ML] Client-side "Start" API method for DataFrame analytics config. (#41570) Implementation of "Start" API method. 
--- .../client/MLRequestConverters.java | 15 ++++ .../client/MachineLearningClient.java | 41 ++++++++++ .../ml/StartDataFrameAnalyticsRequest.java | 79 +++++++++++++++++++ .../client/MLRequestConvertersTests.java | 19 +++++ .../StartDataFrameAnalyticsRequestTests.java | 43 ++++++++++ 5 files changed, 197 insertions(+) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index b869508aaf579..6a0235806e63b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -69,6 +69,7 @@ import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -629,6 +630,20 @@ static Request getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest getS return request; } + static Request startDataFrameAnalytics(StartDataFrameAnalyticsRequest startRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(startRequest.getId()) + .addPathPartAsIs("_start") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(request); + if (startRequest.getTimeout() != null) { + 
params.withTimeout(startRequest.getTimeout()); + } + return request; + } + static Request deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest deleteRequest) { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml", "data_frame", "analytics") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 61b53dac82eaf..215a674a88aad 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -94,6 +94,7 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; @@ -2007,6 +2008,46 @@ public void getDataFrameAnalyticsStatsAsync(GetDataFrameAnalyticsStatsRequest re Collections.emptySet()); } + /** + * Starts Data Frame Analytics + *

+ * For additional info + * see Start Data Frame Analytics documentation + * + * @param request The {@link StartDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return action acknowledgement + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public AcknowledgedResponse startDataFrameAnalytics(StartDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::startDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Starts Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see Start Data Frame Analytics documentation + * + * @param request The {@link StartDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void startDataFrameAnalyticsAsync(StartDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::startDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + listener, + Collections.emptySet()); + } + /** * Deletes the given Data Frame Analytics config *

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..16e43180d57fe --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java @@ -0,0 +1,79 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.unit.TimeValue; + +import java.util.Objects; +import java.util.Optional; + +public class StartDataFrameAnalyticsRequest implements Validatable { + + private final String id; + private TimeValue timeout; + + public StartDataFrameAnalyticsRequest(String id) { + this(id, null); + } + + public StartDataFrameAnalyticsRequest(String id, @Nullable TimeValue timeout) { + this.id = id; + this.timeout = timeout; + } + + public String getId() { + return id; + } + + public TimeValue getTimeout() { + return timeout; + } + + public StartDataFrameAnalyticsRequest setTimeout(@Nullable TimeValue timeout) { + this.timeout = timeout; + return this; + } + + @Override + public Optional validate() { + if (id == null) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StartDataFrameAnalyticsRequest other = (StartDataFrameAnalyticsRequest) o; + return Objects.equals(id, other.id) + && Objects.equals(timeout, other.timeout); + } + + @Override + public int hashCode() { + return Objects.hash(id, timeout); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 6d9046aa6288d..cdd3641f5a412 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -66,6 +66,7 @@ import org.elasticsearch.client.ml.PutJobRequest; import 
org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedRequestTests; import org.elasticsearch.client.ml.StopDatafeedRequest; @@ -720,6 +721,24 @@ public void testGetDataFrameAnalyticsStats() { assertNull(request.getEntity()); } + public void testStartDataFrameAnalytics() { + StartDataFrameAnalyticsRequest startRequest = new StartDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.startDataFrameAnalytics(startRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + startRequest.getId() + "/_start", request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testStartDataFrameAnalytics_WithTimeout() { + StartDataFrameAnalyticsRequest startRequest = new StartDataFrameAnalyticsRequest(randomAlphaOfLength(10)) + .setTimeout(TimeValue.timeValueMinutes(1)); + Request request = MLRequestConverters.startDataFrameAnalytics(startRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + startRequest.getId() + "/_start", request.getEndpoint()); + assertThat(request.getParameters(), hasEntry("timeout", "1m")); + assertNull(request.getEntity()); + } + public void testDeleteDataFrameAnalytics() { DeleteDataFrameAnalyticsRequest deleteRequest = new DeleteDataFrameAnalyticsRequest(randomAlphaOfLength(10)); Request request = MLRequestConverters.deleteDataFrameAnalytics(deleteRequest); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..97367730561cf --- /dev/null +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class StartDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo", null).validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo", TimeValue.ZERO).validate()); + } + + public void testValidate_Failure() { + assertThat(new StartDataFrameAnalyticsRequest(null, null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + assertThat(new StartDataFrameAnalyticsRequest(null, TimeValue.ZERO).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} From fa8d9f48573c2cb34cfdf3fb9ee93ddbc1220e52 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Tue, 30 Apr 2019 22:59:05 +0200 Subject: [PATCH 42/67] Implement integration test for StartDataFrameAnalytics API method (#41664) --- .../client/MachineLearningIT.java | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 597caf26ba27f..fd9f3464bd34e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -28,6 +28,7 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.client.indices.CreateIndexRequest; +import org.elasticsearch.client.indices.GetIndexRequest; import org.elasticsearch.client.ml.CloseJobRequest; import org.elasticsearch.client.ml.CloseJobResponse; import org.elasticsearch.client.ml.DeleteCalendarEventRequest; @@ -92,6 +93,7 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; @@ -1332,6 +1334,50 @@ public void testGetDataFrameAnalyticsStats() throws Exception { assertThat(statsResponse.getTaskFailures(), hasSize(0)); } + public void testStartDataFrameAnalyticsConfig() throws Exception { + String sourceIndex = "start-test-source-index"; + String destIndex = "start-test-dest-index"; + createIndex(sourceIndex); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), 
RequestOptions.DEFAULT); + + // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. + assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "start-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(new DataFrameAnalyticsSource(sourceIndex)) + .setDest(new DataFrameAnalyticsDest(destIndex)) + .setAnalysis(new OutlierDetection()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + + AcknowledgedResponse startDataFrameAnalyticsResponse = execute( + new StartDataFrameAnalyticsRequest(configId), + machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); + assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STARTED)); + + // Wait for the analytics to stop. + assertBusy(() -> assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)), 30, TimeUnit.SECONDS); + + // Verify that the destination index got created. 
+ assertTrue(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + } + + private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOException { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + GetDataFrameAnalyticsStatsResponse statsResponse = + machineLearningClient.getDataFrameAnalyticsStats(new GetDataFrameAnalyticsStatsRequest(configId), RequestOptions.DEFAULT); + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + return stats.getState(); + } + public void testDeleteDataFrameAnalyticsConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String configId = "test-config"; From 362775bb5b0c6b1dad14d0dff9d6787909f4e009 Mon Sep 17 00:00:00 2001 From: Hendrik Muhs Date: Mon, 6 May 2019 14:41:29 +0200 Subject: [PATCH 43/67] [ML-DataFrame] add sync api (#41800) add synchronization API and add a implementation for time-based synchronization --- .../DataFrameNamedXContentProvider.java | 41 +++++ .../transforms/DataFrameTransformConfig.java | 44 +++++- .../dataframe/transforms/SyncConfig.java | 30 ++++ .../dataframe/transforms/TimeSyncConfig.java | 108 +++++++++++++ ...icsearch.plugins.spi.NamedXContentProvider | 1 + .../DataFrameRequestConvertersTests.java | 7 +- .../client/RestHighLevelClientTests.java | 8 +- .../GetDataFrameTransformResponseTests.java | 5 +- ...PreviewDataFrameTransformRequestTests.java | 6 +- .../PutDataFrameTransformRequestTests.java | 6 +- .../DataFrameTransformConfigTests.java | 16 +- .../transforms/TimeSyncConfigTests.java | 49 ++++++ .../transforms/hlrc/TimeSyncConfigTests.java | 59 +++++++ .../DataFrameTransformDocumentationIT.java | 5 +- .../xpack/core/XPackClientPlugin.java | 6 +- .../xpack/core/dataframe/DataFrameField.java | 4 + .../DataFrameNamedXContentProvider.java | 26 +++ 
.../transforms/DataFrameTransformConfig.java | 46 +++++- .../core/dataframe/transforms/SyncConfig.java | 25 +++ .../dataframe/transforms/TimeSyncConfig.java | 148 ++++++++++++++++++ .../AbstractSerializingDataFrameTestCase.java | 8 + ...tractWireSerializingDataFrameTestCase.java | 8 + ...wDataFrameTransformActionRequestTests.java | 10 +- .../AbstractSerializingDataFrameTestCase.java | 4 + .../DataFrameTransformConfigTests.java | 26 +-- .../transforms/TimeSyncConfigTests.java | 38 +++++ .../integration/DataFrameIntegTestCase.java | 1 + .../DataFrameTransformProgressIT.java | 2 + .../xpack/dataframe/DataFrame.java | 7 + 29 files changed, 705 insertions(+), 39 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java new file mode 100644 index 0000000000000..940b136c93daa --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe; + +import org.elasticsearch.client.dataframe.transforms.SyncConfig; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.Arrays; +import java.util.List; + +public class DataFrameNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(SyncConfig.class, + new ParseField(TimeSyncConfig.NAME), + TimeSyncConfig::fromXContent)); + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java index 8465ae8342827..da8e5e735d2a0 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import java.io.IOException; import java.util.Objects; @@ -40,12 +41,14 @@ public class DataFrameTransformConfig implements ToXContentObject { public static final ParseField SOURCE = new ParseField("source"); public static final ParseField DEST = new ParseField("dest"); public static final ParseField DESCRIPTION = new ParseField("description"); + public static final ParseField SYNC = new ParseField("sync"); // types of transforms public static final ParseField PIVOT_TRANSFORM = new ParseField("pivot"); private 
final String id; private final SourceConfig source; private final DestConfig dest; + private final SyncConfig syncConfig; private final PivotConfig pivotConfig; private final String description; @@ -55,19 +58,30 @@ public class DataFrameTransformConfig implements ToXContentObject { String id = (String) args[0]; SourceConfig source = (SourceConfig) args[1]; DestConfig dest = (DestConfig) args[2]; - PivotConfig pivotConfig = (PivotConfig) args[3]; - String description = (String)args[4]; - return new DataFrameTransformConfig(id, source, dest, pivotConfig, description); + SyncConfig syncConfig = (SyncConfig) args[3]; + PivotConfig pivotConfig = (PivotConfig) args[4]; + String description = (String)args[5]; + return new DataFrameTransformConfig(id, source, dest, syncConfig, pivotConfig, description); }); static { PARSER.declareString(constructorArg(), ID); PARSER.declareObject(constructorArg(), (p, c) -> SourceConfig.PARSER.apply(p, null), SOURCE); PARSER.declareObject(constructorArg(), (p, c) -> DestConfig.PARSER.apply(p, null), DEST); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseSyncConfig(p), SYNC); PARSER.declareObject(optionalConstructorArg(), (p, c) -> PivotConfig.fromXContent(p), PIVOT_TRANSFORM); PARSER.declareString(optionalConstructorArg(), DESCRIPTION); } + private static SyncConfig parseSyncConfig(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + SyncConfig syncConfig = parser.namedObject(SyncConfig.class, parser.currentName(), true); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return syncConfig; + } + + public static DataFrameTransformConfig fromXContent(final XContentParser parser) { return PARSER.apply(parser, null); 
} @@ -84,17 +98,19 @@ public static DataFrameTransformConfig fromXContent(final XContentParser parser) * @return A DataFrameTransformConfig to preview, NOTE it will have a {@code null} id, destination and index. */ public static DataFrameTransformConfig forPreview(final SourceConfig source, final PivotConfig pivotConfig) { - return new DataFrameTransformConfig(null, source, null, pivotConfig, null); + return new DataFrameTransformConfig(null, source, null, null, pivotConfig, null); } DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final PivotConfig pivotConfig, final String description) { this.id = id; this.source = source; this.dest = dest; + this.syncConfig = syncConfig; this.pivotConfig = pivotConfig; this.description = description; } @@ -111,6 +127,10 @@ public DestConfig getDestination() { return dest; } + public SyncConfig getSyncConfig() { + return syncConfig; + } + public PivotConfig getPivotConfig() { return pivotConfig; } @@ -132,6 +152,11 @@ public XContentBuilder toXContent(final XContentBuilder builder, final Params pa if (dest != null) { builder.field(DEST.getPreferredName(), dest); } + if (syncConfig != null) { + builder.startObject(SYNC.getPreferredName()); + builder.field(syncConfig.getName(), syncConfig); + builder.endObject(); + } if (pivotConfig != null) { builder.field(PIVOT_TRANSFORM.getPreferredName(), pivotConfig); } @@ -158,12 +183,13 @@ public boolean equals(Object other) { && Objects.equals(this.source, that.source) && Objects.equals(this.dest, that.dest) && Objects.equals(this.description, that.description) + && Objects.equals(this.syncConfig, that.syncConfig) && Objects.equals(this.pivotConfig, that.pivotConfig); } @Override public int hashCode() { - return Objects.hash(id, source, dest, pivotConfig, description); + return Objects.hash(id, source, dest, syncConfig, pivotConfig, description); } @Override @@ -180,6 +206,7 @@ public static class Builder { private 
String id; private SourceConfig source; private DestConfig dest; + private SyncConfig syncConfig; private PivotConfig pivotConfig; private String description; @@ -198,6 +225,11 @@ public Builder setDest(DestConfig dest) { return this; } + public Builder setSyncConfig(SyncConfig syncConfig) { + this.syncConfig = syncConfig; + return this; + } + public Builder setPivotConfig(PivotConfig pivotConfig) { this.pivotConfig = pivotConfig; return this; @@ -209,7 +241,7 @@ public Builder setDescription(String description) { } public DataFrameTransformConfig build() { - return new DataFrameTransformConfig(id, source, dest, pivotConfig, description); + return new DataFrameTransformConfig(id, source, dest, syncConfig, pivotConfig, description); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java new file mode 100644 index 0000000000000..3ead35d0a491a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java @@ -0,0 +1,30 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +public interface SyncConfig extends ToXContentObject { + + /** + * Returns the name of the writeable object + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java new file mode 100644 index 0000000000000..797ca3f896138 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java @@ -0,0 +1,108 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class TimeSyncConfig implements SyncConfig { + + public static final String NAME = "time"; + + private static final ParseField FIELD = new ParseField("field"); + private static final ParseField DELAY = new ParseField("delay"); + + private final String field; + private final TimeValue delay; + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("time_sync_config", true, + args -> new TimeSyncConfig((String) args[0], args[1] != null ? 
(TimeValue) args[1] : TimeValue.ZERO)); + + static { + PARSER.declareString(constructorArg(), FIELD); + PARSER.declareField(optionalConstructorArg(), (p, c) -> TimeValue.parseTimeValue(p.textOrNull(), DELAY.getPreferredName()), DELAY, + ObjectParser.ValueType.STRING_OR_NULL); + } + + public static TimeSyncConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public TimeSyncConfig(String field, TimeValue delay) { + this.field = field; + this.delay = delay; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD.getPreferredName(), field); + if (delay.duration() > 0) { + builder.field(DELAY.getPreferredName(), delay.getStringRep()); + } + builder.endObject(); + return builder; + } + + public String getField() { + return field; + } + + public TimeValue getDelay() { + return delay; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + final TimeSyncConfig that = (TimeSyncConfig) other; + + return Objects.equals(this.field, that.field) + && Objects.equals(this.delay, that.delay); + } + + @Override + public int hashCode() { + return Objects.hash(field, delay); + } + + @Override + public String getName() { + return NAME; + } + +} diff --git a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider index 342c606a540a6..77f1d9700d9a4 100644 --- a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider +++ b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider @@ -1,2 +1,3 @@ 
+org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider \ No newline at end of file diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java index 8c6b1c6045855..6e48cc507c3e7 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java @@ -23,6 +23,7 @@ import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpPut; +import org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider; import org.elasticsearch.client.dataframe.DeleteDataFrameTransformRequest; import org.elasticsearch.client.dataframe.GetDataFrameTransformRequest; import org.elasticsearch.client.dataframe.GetDataFrameTransformStatsRequest; @@ -42,6 +43,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import static org.hamcrest.Matchers.equalTo; @@ -50,7 +52,10 @@ public class DataFrameRequestConvertersTests extends ESTestCase { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } public void testPutDataFrameTransform() throws IOException { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index b9dfa4274c28d..6ca2e3c2bdd24 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -46,6 +46,8 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.core.MainRequest; import org.elasticsearch.client.core.MainResponse; +import org.elasticsearch.client.dataframe.transforms.SyncConfig; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; import org.elasticsearch.client.indexlifecycle.AllocateAction; import org.elasticsearch.client.indexlifecycle.DeleteAction; import org.elasticsearch.client.indexlifecycle.ForceMergeAction; @@ -666,7 +668,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(21, namedXContents.size()); + assertEquals(22, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -676,7 +678,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 5, categories.size()); + assertEquals("Had: " + categories, 6, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -702,6 +704,8 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(SetPriorityAction.NAME)); assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class)); assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); + assertEquals(Integer.valueOf(1), 
categories.get(SyncConfig.class)); + assertTrue(names.contains(TimeSyncConfig.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java index f7386e936301b..27b2cc9b99bf8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java @@ -79,6 +79,9 @@ private static void toXContent(GetDataFrameTransformResponse response, XContentB @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java index c91e1cbb1dd91..45d5d879d47f9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Optional; import static org.elasticsearch.client.dataframe.transforms.SourceConfigTests.randomSourceConfig; @@ -55,7 +56,10 @@ protected boolean 
supportsUnknownFields() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } public void testValidate() { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java index 28fd92dcf913f..7c7cd3fa151fe 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Optional; import static org.hamcrest.Matchers.containsString; @@ -71,6 +72,9 @@ protected boolean supportsUnknownFields() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java index 1b5228d96229f..803ad39b9fb63 100644 --- 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.dataframe.transforms; +import org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider; import org.elasticsearch.client.dataframe.transforms.pivot.PivotConfigTests; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -28,6 +29,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.function.Predicate; import static org.elasticsearch.client.dataframe.transforms.DestConfigTests.randomDestConfig; @@ -36,8 +38,13 @@ public class DataFrameTransformConfigTests extends AbstractXContentTestCase { public static DataFrameTransformConfig randomDataFrameTransformConfig() { - return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), - randomDestConfig(), PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100)); + return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), randomDestConfig(), + randomBoolean() ? randomSyncConfig() : null, PivotConfigTests.randomPivotConfig(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 100)); + } + + public static SyncConfig randomSyncConfig() { + return TimeSyncConfigTests.randomTimeSyncConfig(); } @Override @@ -64,6 +71,9 @@ protected Predicate getRandomFieldsExcludeFilter() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..dd2a17eb0260d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TimeSyncConfigTests extends AbstractXContentTestCase { + + public static TimeSyncConfig randomTimeSyncConfig() { + return new TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), new TimeValue(randomNonNegativeLong())); + } + + @Override + protected TimeSyncConfig createTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected TimeSyncConfig doParseInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..0c6a0350882a4 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java @@ -0,0 +1,59 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.dataframe.transforms.hlrc; + +import org.elasticsearch.client.AbstractResponseTestCase; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; + +public class TimeSyncConfigTests + extends AbstractResponseTestCase { + + public static org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig randomTimeSyncConfig() { + return new org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), + new TimeValue(randomNonNegativeLong())); + } + + public static void assertHlrcEquals(org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig serverTestInstance, + TimeSyncConfig clientInstance) { + assertEquals(serverTestInstance.getField(), clientInstance.getField()); + assertEquals(serverTestInstance.getDelay(), clientInstance.getDelay()); + } + + @Override + protected org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig createServerTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected TimeSyncConfig doParseToClientInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser); + } + + @Override + protected void assertInstances(org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig serverTestInstance, + TimeSyncConfig clientInstance) { + assertHlrcEquals(serverTestInstance, clientInstance); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java index 60dd2cb32eaab..a41540af3fb4a 100644 --- 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java @@ -422,6 +422,7 @@ public void testPreview() throws IOException, InterruptedException { .setQueryConfig(queryConfig) .build(), // <1> pivotConfig); // <2> + PreviewDataFrameTransformRequest request = new PreviewDataFrameTransformRequest(transformConfig); // <3> // end::preview-data-frame-transform-request @@ -469,7 +470,6 @@ public void testGetStats() throws IOException, InterruptedException { RestHighLevelClient client = highLevelClient(); - QueryConfig queryConfig = new QueryConfig(new MatchAllQueryBuilder()); GroupConfig groupConfig = GroupConfig.builder().groupBy("reviewer", TermsGroupSource.builder().setField("user_id").build()).build(); AggregatorFactories.Builder aggBuilder = new AggregatorFactories.Builder(); @@ -554,8 +554,7 @@ public void onFailure(Exception e) { public void testGetDataFrameTransform() throws IOException, InterruptedException { createIndex("source-data"); - - QueryConfig queryConfig = new QueryConfig(new MatchAllQueryBuilder()); + GroupConfig groupConfig = GroupConfig.builder().groupBy("reviewer", TermsGroupSource.builder().setField("user_id").build()).build(); AggregatorFactories.Builder aggBuilder = new AggregatorFactories.Builder(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 74c2b3d6af407..07fac1baa8b58 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -54,6 +54,8 @@ import org.elasticsearch.xpack.core.dataframe.action.StopDataFrameTransformAction; import 
org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransform; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformState; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.elasticsearch.xpack.core.deprecation.DeprecationInfoAction; import org.elasticsearch.xpack.core.graph.GraphFeatureSetUsage; import org.elasticsearch.xpack.core.graph.action.GraphExploreAction; @@ -511,7 +513,9 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.DATA_FRAME, DataFrameFeatureSetUsage::new), new NamedWriteableRegistry.Entry(PersistentTaskParams.class, DataFrameField.TASK_NAME, DataFrameTransform::new), new NamedWriteableRegistry.Entry(Task.Status.class, DataFrameField.TASK_NAME, DataFrameTransformState::new), - new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameField.TASK_NAME, DataFrameTransformState::new) + new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameField.TASK_NAME, DataFrameTransformState::new), + new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), TimeSyncConfig::new) + ); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java index 71bf14cdeb4a5..d6dd5a30f5d65 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java @@ -27,6 +27,10 @@ public final class DataFrameField { public static final ParseField SOURCE = new ParseField("source"); public static final ParseField DESTINATION = new ParseField("dest"); public static final ParseField FORCE = new ParseField("force"); + public static final ParseField FIELD = new 
ParseField("field"); + public static final ParseField SYNC = new ParseField("sync"); + public static final ParseField TIME_BASED_SYNC = new ParseField("time"); + public static final ParseField DELAY = new ParseField("delay"); /** * Fields for checkpointing diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java new file mode 100644 index 0000000000000..9eacfc5ff1eae --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.dataframe; + +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; + +import java.util.Arrays; +import java.util.List; + +public class DataFrameNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(SyncConfig.class, + DataFrameField.TIME_BASED_SYNC, + TimeSyncConfig::parse)); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java index ee35fe3d21ec7..2750daea6b3fd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xpack.core.dataframe.DataFrameField; import org.elasticsearch.xpack.core.dataframe.DataFrameMessages; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig; @@ -49,6 +50,7 @@ public class DataFrameTransformConfig extends AbstractDiffable create SourceConfig source = (SourceConfig) args[1]; DestConfig dest = (DestConfig) args[2]; - // ignored, only for internal storage: String docType = (String) args[3]; + SyncConfig syncConfig = (SyncConfig) args[3]; + // ignored, only for internal storage: String docType = (String) args[4]; // on strict parsing do not allow injection of headers - if (lenient == false && args[4] != null) { + if (lenient == false && args[5] != null) { throw new IllegalArgumentException("Found [headers], not allowed for strict parsing"); } @SuppressWarnings("unchecked") - Map headers = (Map) args[4]; + Map headers = (Map) args[5]; - PivotConfig pivotConfig = (PivotConfig) args[5]; - String description = (String)args[6]; - return new DataFrameTransformConfig(id, source, dest, headers, pivotConfig, description); + PivotConfig pivotConfig = (PivotConfig) args[6]; + String description = (String)args[7]; + return new DataFrameTransformConfig(id, source, dest, syncConfig, headers, pivotConfig, description); }); parser.declareString(optionalConstructorArg(), DataFrameField.ID); parser.declareObject(constructorArg(), (p, c) -> SourceConfig.fromXContent(p, lenient), DataFrameField.SOURCE); parser.declareObject(constructorArg(), (p, c) -> DestConfig.fromXContent(p, lenient), DataFrameField.DESTINATION); + parser.declareObject(optionalConstructorArg(), 
(p, c) -> parseSyncConfig(p, lenient), DataFrameField.SYNC); + parser.declareString(optionalConstructorArg(), DataFrameField.INDEX_DOC_TYPE); + parser.declareObject(optionalConstructorArg(), (p, c) -> p.mapStrings(), HEADERS); parser.declareObject(optionalConstructorArg(), (p, c) -> PivotConfig.fromXContent(p, lenient), PIVOT_TRANSFORM); parser.declareString(optionalConstructorArg(), DESCRIPTION); @@ -99,6 +105,14 @@ private static ConstructingObjectParser create return parser; } + private static SyncConfig parseSyncConfig(XContentParser parser, boolean ignoreUnknownFields) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + SyncConfig syncConfig = parser.namedObject(SyncConfig.class, parser.currentName(), ignoreUnknownFields); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return syncConfig; + } + public static String documentId(String transformId) { return NAME + "-" + transformId; } @@ -106,12 +120,14 @@ public static String documentId(String transformId) { public DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final Map headers, final PivotConfig pivotConfig, final String description) { this.id = ExceptionsHelper.requireNonNull(id, DataFrameField.ID.getPreferredName()); this.source = ExceptionsHelper.requireNonNull(source, DataFrameField.SOURCE.getPreferredName()); this.dest = ExceptionsHelper.requireNonNull(dest, DataFrameField.DESTINATION.getPreferredName()); + this.syncConfig = syncConfig; this.setHeaders(headers == null ? 
Collections.emptyMap() : headers); this.pivotConfig = pivotConfig; this.description = description; @@ -129,6 +145,7 @@ public DataFrameTransformConfig(final StreamInput in) throws IOException { id = in.readString(); source = new SourceConfig(in); dest = new DestConfig(in); + syncConfig = in.readOptionalNamedWriteable(SyncConfig.class); setHeaders(in.readMap(StreamInput::readString, StreamInput::readString)); pivotConfig = in.readOptionalWriteable(PivotConfig::new); description = in.readOptionalString(); @@ -146,6 +163,10 @@ public DestConfig getDestination() { return dest; } + public SyncConfig getSyncConfig() { + return syncConfig; + } + public Map getHeaders() { return headers; } @@ -168,6 +189,10 @@ public boolean isValid() { return false; } + if (syncConfig != null && syncConfig.isValid() == false) { + return false; + } + return source.isValid() && dest.isValid(); } @@ -176,6 +201,7 @@ public void writeTo(final StreamOutput out) throws IOException { out.writeString(id); source.writeTo(out); dest.writeTo(out); + out.writeOptionalNamedWriteable(syncConfig); out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); out.writeOptionalWriteable(pivotConfig); out.writeOptionalString(description); @@ -187,6 +213,11 @@ public XContentBuilder toXContent(final XContentBuilder builder, final Params pa builder.field(DataFrameField.ID.getPreferredName(), id); builder.field(DataFrameField.SOURCE.getPreferredName(), source); builder.field(DataFrameField.DESTINATION.getPreferredName(), dest); + if (syncConfig != null) { + builder.startObject(DataFrameField.SYNC.getPreferredName()); + builder.field(syncConfig.getWriteableName(), syncConfig); + builder.endObject(); + } if (pivotConfig != null) { builder.field(PIVOT_TRANSFORM.getPreferredName(), pivotConfig); } @@ -218,6 +249,7 @@ public boolean equals(Object other) { return Objects.equals(this.id, that.id) && Objects.equals(this.source, that.source) && Objects.equals(this.dest, that.dest) + && 
Objects.equals(this.syncConfig, that.syncConfig) && Objects.equals(this.headers, that.headers) && Objects.equals(this.pivotConfig, that.pivotConfig) && Objects.equals(this.description, that.description); @@ -225,7 +257,7 @@ public boolean equals(Object other) { @Override public int hashCode(){ - return Objects.hash(id, source, dest, headers, pivotConfig, description); + return Objects.hash(id, source, dest, syncConfig, headers, pivotConfig, description); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java new file mode 100644 index 0000000000000..d8008f12126e3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.dataframe.transforms; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.index.query.QueryBuilder; + +public interface SyncConfig extends ToXContentObject, NamedWriteable { + + /** + * Validate configuration + * + * @return true if valid + */ + boolean isValid(); + + QueryBuilder getBoundaryQuery(DataFrameTransformCheckpoint checkpoint); + + QueryBuilder getChangesQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java new file mode 100644 index 0000000000000..9a949ab21e8c3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.dataframe.transforms; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.RangeQueryBuilder; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class TimeSyncConfig implements SyncConfig { + + private static final String NAME = "data_frame_transform_pivot_sync_time"; + + private final String field; + private final TimeValue delay; + + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, lenient, + args -> { + String field = (String) args[0]; + TimeValue delay = args[1] != null ? 
(TimeValue) args[1] : TimeValue.ZERO; + + return new TimeSyncConfig(field, delay); + }); + + parser.declareString(constructorArg(), DataFrameField.FIELD); + parser.declareField(optionalConstructorArg(), + (p, c) -> TimeValue.parseTimeValue(p.textOrNull(), DataFrameField.DELAY.getPreferredName()), DataFrameField.DELAY, + ObjectParser.ValueType.STRING_OR_NULL); + + return parser; + } + + public TimeSyncConfig() { + this(null, null); + } + + public TimeSyncConfig(final String field, final TimeValue delay) { + this.field = ExceptionsHelper.requireNonNull(field, DataFrameField.FIELD.getPreferredName()); + this.delay = ExceptionsHelper.requireNonNull(delay, DataFrameField.DELAY.getPreferredName()); + } + + public TimeSyncConfig(StreamInput in) throws IOException { + this.field = in.readString(); + this.delay = in.readTimeValue(); + } + + public String getField() { + return field; + } + + public TimeValue getDelay() { + return delay; + } + + @Override + public boolean isValid() { + return true; + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + out.writeString(field); + out.writeTimeValue(delay); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + builder.field(DataFrameField.FIELD.getPreferredName(), field); + if (delay.duration() > 0) { + builder.field(DataFrameField.DELAY.getPreferredName(), delay.getStringRep()); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + final TimeSyncConfig that = (TimeSyncConfig) other; + + return Objects.equals(this.field, that.field) + && Objects.equals(this.delay, that.delay); + } + + @Override + public int hashCode(){ + return Objects.hash(field, delay); + } + + @Override + public String toString() { + return 
Strings.toString(this, true, true); + } + + public static TimeSyncConfig parse(final XContentParser parser) { + return LENIENT_PARSER.apply(parser, null); + } + + public static TimeSyncConfig fromXContent(final XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + @Override + public String getWriteableName() { + return DataFrameField.TIME_BASED_SYNC.getPreferredName(); + } + + @Override + public QueryBuilder getBoundaryQuery(DataFrameTransformCheckpoint checkpoint) { + return new RangeQueryBuilder(field).lt(checkpoint.getTimeUpperBound()).format("epoch_millis"); + } + + @Override + public QueryBuilder getChangesQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint) { + return new RangeQueryBuilder(field).gte(oldCheckpoint.getTimeUpperBound()).lt(newCheckpoint.getTimeUpperBound()) + .format("epoch_millis"); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java index 8b633cdfc26d5..14cbdef148ca4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java @@ -13,6 +13,10 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import 
org.junit.Before; import java.util.List; @@ -30,7 +34,11 @@ public void registerNamedObjects() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); List namedWriteables = searchModule.getNamedWriteables(); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java index 91a7ec54dd256..47d7860b71da0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java @@ -12,6 +12,10 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.junit.Before; import java.util.List; @@ -30,7 +34,11 @@ public void registerNamedObjects() { SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); List namedWriteables = searchModule.getNamedWriteables(); + namedWriteables.add(new 
NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java index 0cfc659e50646..0e0187d8ef02b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.dataframe.action.PreviewDataFrameTransformAction.Request; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfigTests; import org.elasticsearch.xpack.core.dataframe.transforms.DestConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfigTests; @@ -39,9 +40,14 @@ protected boolean supportsUnknownFields() { @Override protected Request createTestInstance() { - DataFrameTransformConfig config = new DataFrameTransformConfig("transform-preview", randomSourceConfig(), + DataFrameTransformConfig config = new DataFrameTransformConfig( + "transform-preview", + randomSourceConfig(), new DestConfig("unused-transform-preview-index"), - null, PivotConfigTests.randomPivotConfig(), null); + randomBoolean() ? 
DataFrameTransformConfigTests.randomSyncConfig() : null, + null, + PivotConfigTests.randomPivotConfig(), + null); return new Request(config); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java index 2b64fadac051a..79edb8084551d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.BaseAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; import org.junit.Before; import java.util.Collections; @@ -48,12 +49,15 @@ public void registerAggregationNamedObjects() throws Exception { MockDeprecatedQueryBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(AggregationBuilder.class, MockDeprecatedAggregationBuilder.NAME, MockDeprecatedAggregationBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); List namedXContents = searchModule.getNamedXContents(); namedXContents.add(new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(MockDeprecatedQueryBuilder.NAME), (p, c) -> MockDeprecatedQueryBuilder.fromXContent(p))); namedXContents.add(new NamedXContentRegistry.Entry(BaseAggregationBuilder.class, new ParseField(MockDeprecatedAggregationBuilder.NAME), (p, c) -> MockDeprecatedAggregationBuilder.fromXContent(p))); + namedXContents.addAll(new 
DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java index a735b5a02acb8..8b46b8cd2838a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java @@ -41,24 +41,28 @@ public static DataFrameTransformConfig randomDataFrameTransformConfig() { } public static DataFrameTransformConfig randomDataFrameTransformConfigWithoutHeaders(String id) { - return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), null, + return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), randomBoolean() ? randomSyncConfig() : null, null, PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000)); } public static DataFrameTransformConfig randomDataFrameTransformConfig(String id) { - return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), randomHeaders(), - PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000)); + return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), randomBoolean() ? randomSyncConfig() : null, + randomHeaders(), PivotConfigTests.randomPivotConfig(), randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 1000)); } public static DataFrameTransformConfig randomInvalidDataFrameTransformConfig() { if (randomBoolean()) { - return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomInvalidSourceConfig(), - randomDestConfig(), randomHeaders(), PivotConfigTests.randomPivotConfig(), - randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100)); + return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomInvalidSourceConfig(), randomDestConfig(), + randomBoolean() ? randomSyncConfig() : null, randomHeaders(), PivotConfigTests.randomPivotConfig(), + randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000)); } // else - return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), - randomDestConfig(), randomHeaders(), PivotConfigTests.randomInvalidPivotConfig(), - randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100)); + return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), randomDestConfig(), + randomBoolean() ? randomSyncConfig() : null, randomHeaders(), PivotConfigTests.randomInvalidPivotConfig(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 1000)); + } + + public static SyncConfig randomSyncConfig() { + return TimeSyncConfigTests.randomTimeSyncConfig(); } @Before @@ -167,11 +171,11 @@ public void testXContentForInternalStorage() throws IOException { public void testMaxLengthDescription() { IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new DataFrameTransformConfig("id", - randomSourceConfig(), randomDestConfig(), null, PivotConfigTests.randomPivotConfig(), randomAlphaOfLength(1001))); + randomSourceConfig(), randomDestConfig(), null, null, PivotConfigTests.randomPivotConfig(), randomAlphaOfLength(1001))); assertThat(exception.getMessage(), equalTo("[description] must be less than 1000 characters in length.")); String description = randomAlphaOfLength(1000); DataFrameTransformConfig config = new DataFrameTransformConfig("id", - randomSourceConfig(), randomDestConfig(), null, PivotConfigTests.randomPivotConfig(), description); + randomSourceConfig(), randomDestConfig(), null, null, PivotConfigTests.randomPivotConfig(), description); assertThat(description, equalTo(config.getDescription())); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..763e13e77aee0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.dataframe.transforms; + +import org.elasticsearch.common.io.stream.Writeable.Reader; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; + +import java.io.IOException; + +public class TimeSyncConfigTests extends AbstractSerializingTestCase { + + public static TimeSyncConfig randomTimeSyncConfig() { + return new TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), new TimeValue(randomNonNegativeLong())); + } + + @Override + protected TimeSyncConfig doParseInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser, false); + } + + @Override + protected TimeSyncConfig createTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected Reader instanceReader() { + return TimeSyncConfig::new; + } + +} diff --git a/x-pack/plugin/data-frame/qa/multi-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameIntegTestCase.java b/x-pack/plugin/data-frame/qa/multi-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameIntegTestCase.java index 84f3e05de5cd1..0eeaf6d405298 100644 --- a/x-pack/plugin/data-frame/qa/multi-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameIntegTestCase.java +++ b/x-pack/plugin/data-frame/qa/multi-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameIntegTestCase.java @@ -192,6 +192,7 @@ protected DataFrameTransformConfig createTransformConfig(String id, return new DataFrameTransformConfig(id, new SourceConfig(sourceIndices, createQueryConfig(queryBuilder)), new DestConfig(destinationIndex), + null, Collections.emptyMap(), createPivotConfig(groups, aggregations), "Test data frame transform config id: " + id); diff --git 
a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java index d338d6949f07b..ec48c10ee6c9a 100644 --- a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java +++ b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java @@ -135,6 +135,7 @@ public void testGetProgress() throws Exception { sourceConfig, destConfig, null, + null, pivotConfig, null); @@ -155,6 +156,7 @@ public void testGetProgress() throws Exception { sourceConfig, destConfig, null, + null, pivotConfig, null); diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/DataFrame.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/DataFrame.java index b7e6c235f8e6c..75c8c7443bd72 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/DataFrame.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/DataFrame.java @@ -24,6 +24,7 @@ import org.elasticsearch.common.settings.SettingsFilter; import org.elasticsearch.common.settings.SettingsModule; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry.Entry; import org.elasticsearch.env.Environment; import org.elasticsearch.env.NodeEnvironment; import org.elasticsearch.license.XPackLicenseState; @@ -40,6 +41,7 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; import 
org.elasticsearch.xpack.core.dataframe.action.DeleteDataFrameTransformAction; import org.elasticsearch.xpack.core.dataframe.action.GetDataFrameTransformsAction; import org.elasticsearch.xpack.core.dataframe.action.GetDataFrameTransformsStatsAction; @@ -231,4 +233,9 @@ public void close() { schedulerEngine.get().stop(); } } + + @Override + public List getNamedXContent() { + return new DataFrameNamedXContentProvider().getNamedXContentParsers(); + } } From 2854bbc0cb9e41ebad6acf51976180574011ec3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Tue, 7 May 2019 09:10:55 +0200 Subject: [PATCH 44/67] [ML] Refine client code for data frame analytics (#41665) Make the client code more usable and uniform: * Introduce Builder classes * Use Strings.toString consistently * Remove unused copy constructors --- .../ml/dataframe/DataFrameAnalysis.java | 4 - .../dataframe/DataFrameAnalyticsConfig.java | 32 +------ .../ml/dataframe/DataFrameAnalyticsDest.java | 71 ++++++++------ .../dataframe/DataFrameAnalyticsSource.java | 65 +++++++------ .../ml/dataframe/DataFrameAnalyticsStats.java | 3 +- ...ataFrameAnalysisNamedXContentProvider.java | 14 ++- .../client/ml/dataframe/OutlierDetection.java | 95 ++++++++++--------- .../client/ml/dataframe/QueryConfig.java | 10 +- .../client/MachineLearningIT.java | 73 +++++++++----- .../DataFrameAnalyticsDestTests.java | 5 +- .../DataFrameAnalyticsSourceTests.java | 7 +- .../ml/dataframe/OutlierDetectionTests.java | 28 +++--- 12 files changed, 211 insertions(+), 196 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java index 585d135700aa4..81b19eefce573 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java +++ 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java @@ -21,11 +21,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; -import java.util.Map; - public interface DataFrameAnalysis extends ToXContentObject { String getName(); - - Map getParams(); } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java index a5ede0e9128be..b1309e66afcd4 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java @@ -21,7 +21,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -43,11 +43,9 @@ public static DataFrameAnalyticsConfig fromXContent(XContentParser parser) { } public static Builder builder(String id) { - return new Builder(id); + return new Builder().setId(id); } - private static final String NAME = "data_frame_analytics_config"; - private static final ParseField ID = new ParseField("id"); private static final ParseField SOURCE = new ParseField("source"); private static final ParseField DEST = new ParseField("dest"); @@ -55,7 +53,7 @@ public static Builder builder(String id) { private static final ParseField ANALYZED_FIELDS = new ParseField("analyzed_fields"); private static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); - private static ObjectParser PARSER = new ObjectParser<>(NAME, true, Builder::new); + private static ObjectParser PARSER = new 
ObjectParser<>("data_frame_analytics_config", true, Builder::new); static { PARSER.declareString(Builder::setId, ID); @@ -159,14 +157,7 @@ public int hashCode() { @Override public String toString() { - return new ToStringBuilder(getClass()) - .add("id", id) - .add("source", source) - .add("dest", dest) - .add("analysis", analysis) - .add("analyzedFields", analyzedFields) - .add("modelMemoryLimit", modelMemoryLimit) - .toString(); + return Strings.toString(this); } public static class Builder { @@ -180,21 +171,6 @@ public static class Builder { private Builder() {} - private Builder(String id) { - setId(id); - } - - public Builder(DataFrameAnalyticsConfig config) { - this.id = config.id; - this.source = new DataFrameAnalyticsSource(config.source); - this.dest = new DataFrameAnalyticsDest(config.dest); - this.analysis = config.analysis; - if (config.analyzedFields != null) { - this.analyzedFields = new FetchSourceContext(true, config.analyzedFields.includes(), config.analyzedFields.excludes()); - } - this.modelMemoryLimit = config.modelMemoryLimit; - } - public Builder setId(String id) { this.id = Objects.requireNonNull(id); return this; diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java index c15ca05c969b0..4123f85ee2f43 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java @@ -21,8 +21,8 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.inject.internal.ToStringBuilder; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; import 
org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -35,39 +35,37 @@ public class DataFrameAnalyticsDest implements ToXContentObject { public static DataFrameAnalyticsDest fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + return PARSER.apply(parser, null).build(); + } + + public static Builder builder() { + return new Builder(); } private static final ParseField INDEX = new ParseField("index"); private static final ParseField RESULTS_FIELD = new ParseField("results_field"); - private static ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("data_frame_analytics_dest", true, - (args) -> { - String index = (String) args[0]; - String resultsField = (String) args[1]; - return new DataFrameAnalyticsDest(index, resultsField); - }); + private static ObjectParser PARSER = new ObjectParser<>("data_frame_analytics_dest", true, Builder::new); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD); + PARSER.declareString(Builder::setIndex, INDEX); + PARSER.declareString(Builder::setResultsField, RESULTS_FIELD); } private final String index; private final String resultsField; - public DataFrameAnalyticsDest(String index) { - this(index, null); - } - - public DataFrameAnalyticsDest(String index, @Nullable String resultsField) { + private DataFrameAnalyticsDest(String index, @Nullable String resultsField) { this.index = requireNonNull(index); this.resultsField = resultsField; } - public DataFrameAnalyticsDest(DataFrameAnalyticsDest other) { - this(other.index, other.resultsField); + public String getIndex() { + return index; + } + + public String getResultsField() { + return resultsField; } @Override @@ -91,24 +89,35 @@ public boolean equals(Object o) { && Objects.equals(resultsField, other.resultsField); 
} - @Override - public String toString() { - return new ToStringBuilder(getClass()) - .add("index", index) - .add("resultsField", resultsField) - .toString(); - } - @Override public int hashCode() { return Objects.hash(index, resultsField); } - public String getIndex() { - return index; + @Override + public String toString() { + return Strings.toString(this); } - public String getResultsField() { - return resultsField; + public static class Builder { + + private String index; + private String resultsField; + + private Builder() {} + + public Builder setIndex(String index) { + this.index = index; + return this; + } + + public Builder setResultsField(String resultsField) { + this.resultsField = resultsField; + return this; + } + + public DataFrameAnalyticsDest build() { + return new DataFrameAnalyticsDest(index, resultsField); + } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java index f18f65ad3f66c..c36799cd3b4a7 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java @@ -21,8 +21,8 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.inject.internal.ToStringBuilder; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -33,39 +33,37 @@ public class DataFrameAnalyticsSource implements ToXContentObject { public static DataFrameAnalyticsSource fromXContent(XContentParser 
parser) { - return PARSER.apply(parser, null); + return PARSER.apply(parser, null).build(); + } + + public static Builder builder() { + return new Builder(); } private static final ParseField INDEX = new ParseField("index"); private static final ParseField QUERY = new ParseField("query"); - private static ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("data_frame_analytics_source", true, - (args) -> { - String index = (String) args[0]; - QueryConfig queryConfig = (QueryConfig) args[1]; - return new DataFrameAnalyticsSource(index, queryConfig); - }); + private static ObjectParser PARSER = new ObjectParser<>("data_frame_analytics_source", true, Builder::new); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX); - PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), (p, c) -> QueryConfig.fromXContent(p), QUERY); + PARSER.declareString(Builder::setIndex, INDEX); + PARSER.declareObject(Builder::setQueryConfig, (p, c) -> QueryConfig.fromXContent(p), QUERY); } private final String index; private final QueryConfig queryConfig; - public DataFrameAnalyticsSource(String index) { - this(index, null); - } - - public DataFrameAnalyticsSource(String index, @Nullable QueryConfig queryConfig) { + private DataFrameAnalyticsSource(String index, @Nullable QueryConfig queryConfig) { this.index = Objects.requireNonNull(index); this.queryConfig = queryConfig; } - public DataFrameAnalyticsSource(DataFrameAnalyticsSource other) { - this(other.index, new QueryConfig(other.queryConfig)); + public String getIndex() { + return index; + } + + public QueryConfig getQueryConfig() { + return queryConfig; } @Override @@ -96,17 +94,28 @@ public int hashCode() { @Override public String toString() { - return new ToStringBuilder(getClass()) - .add("index", index) - .add("queryConfig", queryConfig) - .toString(); + return Strings.toString(this); } - public String getIndex() { - return index; - } + public static class Builder { - public 
QueryConfig getQueryConfig() { - return queryConfig; + private String index; + private QueryConfig queryConfig; + + private Builder() {} + + public Builder setIndex(String index) { + this.index = index; + return this; + } + + public Builder setQueryConfig(QueryConfig queryConfig) { + this.queryConfig = queryConfig; + return this; + } + + public DataFrameAnalyticsSource build() { + return new DataFrameAnalyticsSource(index, queryConfig); + } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java index ef9dced4263cc..5c652f33edb2e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java @@ -46,8 +46,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation"); private static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>( - "data_frame_analytics_stats", true, + new ConstructingObjectParser<>("data_frame_analytics_stats", true, args -> new DataFrameAnalyticsStats( (String) args[0], (DataFrameAnalyticsState) args[1], diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java index 3b3a28eb3a8b5..3b78c60be91fd 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java @@ -21,19 +21,17 @@ import 
org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider { @Override public List getNamedXContentParsers() { - List namedXContent = new ArrayList<>(); - - namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> { - return OutlierDetection.fromXContent(p); - })); - - return namedXContent; + return Arrays.asList( + new NamedXContentRegistry.Entry( + DataFrameAnalysis.class, + OutlierDetection.NAME, + (p, c) -> OutlierDetection.fromXContent(p))); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java index be5334c14d518..bb0ecff6865ed 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java @@ -21,42 +21,38 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.inject.internal.ToStringBuilder; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; -import java.util.HashMap; import java.util.Locale; -import java.util.Map; import java.util.Objects; public class OutlierDetection implements DataFrameAnalysis { public static OutlierDetection fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + return PARSER.apply(parser, null).build(); + } + + public 
static OutlierDetection createDefault() { + return builder().build(); + } + + public static Builder builder() { + return new Builder(); } public static final ParseField NAME = new ParseField("outlier_detection"); static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); static final ParseField METHOD = new ParseField("method"); - private static ConstructingObjectParser PARSER = - new ConstructingObjectParser<>(NAME.getPreferredName(), true, - (args) -> { - Integer nNeighbors = (Integer) args[0]; - Method method = (Method) args[1]; - return new OutlierDetection(nNeighbors, method); - }); - - private final Integer nNeighbors; - private final Method method; + private static ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Builder::new); static { - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), N_NEIGHBORS); - PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { + PARSER.declareInt(Builder::setNNeighbors, N_NEIGHBORS); + PARSER.declareField(Builder::setMethod, p -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { return Method.fromString(p.text()); } @@ -64,12 +60,15 @@ public static OutlierDetection fromXContent(XContentParser parser) { }, METHOD, ObjectParser.ValueType.STRING); } + private final Integer nNeighbors; + private final Method method; + /** * Constructs the outlier detection configuration * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. * @param method The method. Leave unspecified for a dynamic mixture of methods. 
*/ - public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method) { + private OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method) { if (nNeighbors != null && nNeighbors <= 0) { throw new IllegalArgumentException("[" + N_NEIGHBORS.getPreferredName() + "] must be a positive integer"); } @@ -78,11 +77,17 @@ public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method) { this.method = method; } - /** - * Constructs the default outlier detection configuration - */ - public OutlierDetection() { - this(null, null); + @Override + public String getName() { + return NAME.getPreferredName(); + } + + public Integer getNNeighbors() { + return nNeighbors; + } + + public Method getMethod() { + return method; } @Override @@ -115,27 +120,7 @@ public int hashCode() { @Override public String toString() { - return new ToStringBuilder(getClass()) - .add("nNeighbors", nNeighbors) - .add("method", method) - .toString(); - } - - @Override - public String getName() { - return NAME.getPreferredName(); - } - - @Override - public Map getParams() { - Map params = new HashMap<>(); - if (nNeighbors != null) { - params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); - } - if (method != null) { - params.put(METHOD.getPreferredName(), method); - } - return params; + return Strings.toString(this); } public enum Method { @@ -150,4 +135,26 @@ public String toString() { return name().toLowerCase(Locale.ROOT); } } + + public static class Builder { + + private Integer nNeighbors; + private Method method; + + private Builder() {} + + public Builder setNNeighbors(Integer nNeighbors) { + this.nNeighbors = nNeighbors; + return this; + } + + public Builder setMethod(Method method) { + this.method = method; + return this; + } + + public OutlierDetection build() { + return new OutlierDetection(nNeighbors, method); + } + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java index f3694dd50cb54..ae704db9f800e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java @@ -19,7 +19,7 @@ package org.elasticsearch.client.ml.dataframe; -import org.elasticsearch.common.inject.internal.ToStringBuilder; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; @@ -77,12 +77,6 @@ public int hashCode() { @Override public String toString() { - return new ToStringBuilder(getClass()) - .add("query", query) - .toString(); - } - - public boolean isValid() { - return query != null; + return Strings.toString(this); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index fd9f3464bd34e..3b15cf08477e8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1194,11 +1194,15 @@ public void testDeleteCalendarEvent() throws IOException { public void testPutDataFrameAnalyticsConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); - String configId = "test-config"; + String configId = "put-test-config"; DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) - .setSource(new DataFrameAnalyticsSource("test-source-index")) - .setDest(new DataFrameAnalyticsDest("test-dest-index")) - .setAnalysis(new OutlierDetection()) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + 
.setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) .build(); PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( @@ -1217,11 +1221,15 @@ public void testPutDataFrameAnalyticsConfig() throws Exception { public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); - String configId = "test-config"; + String configId = "get-test-config"; DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) - .setSource(new DataFrameAnalyticsSource("test-source-index")) - .setDest(new DataFrameAnalyticsDest("test-dest-index")) - .setAnalysis(new OutlierDetection()) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("get-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("get-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) .build(); PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( @@ -1238,17 +1246,20 @@ public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { public void testGetDataFrameAnalyticsConfig_MultipleConfigs() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); - String configIdPrefix = "test-config-"; + String configIdPrefix = "get-test-config-"; int numberOfConfigs = 10; List createdConfigs = new ArrayList<>(); for (int i = 0; i < numberOfConfigs; ++i) { String configId = configIdPrefix + i; - DataFrameAnalyticsConfig config = - DataFrameAnalyticsConfig.builder(configId) - .setSource(new DataFrameAnalyticsSource("index-source-test")) - .setDest(new DataFrameAnalyticsDest("index-dest-test")) - .setAnalysis(new OutlierDetection()) - .build(); + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + 
.setIndex("get-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("get-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( new PutDataFrameAnalyticsRequest(config), @@ -1310,9 +1321,13 @@ public void testGetDataFrameAnalyticsStats() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String configId = "get-stats-test-config"; DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) - .setSource(new DataFrameAnalyticsSource(sourceIndex)) - .setDest(new DataFrameAnalyticsDest(destIndex)) - .setAnalysis(new OutlierDetection()) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) .build(); execute( @@ -1346,9 +1361,13 @@ public void testStartDataFrameAnalyticsConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String configId = "start-test-config"; DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) - .setSource(new DataFrameAnalyticsSource(sourceIndex)) - .setDest(new DataFrameAnalyticsDest(destIndex)) - .setAnalysis(new OutlierDetection()) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) .build(); execute( @@ -1380,11 +1399,15 @@ private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOExce public void testDeleteDataFrameAnalyticsConfig() throws Exception { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); - String configId = "test-config"; + String configId = "delete-test-config"; 
DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) - .setSource(new DataFrameAnalyticsSource("test-source-index")) - .setDest(new DataFrameAnalyticsDest("test-dest-index")) - .setAnalysis(new OutlierDetection()) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("delete-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("delete-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) .build(); GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java index 8e208cfbc7f99..dce7ca5204d57 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -27,7 +27,10 @@ public class DataFrameAnalyticsDestTests extends AbstractXContentTestCase { public static DataFrameAnalyticsDest randomDestConfig() { - return new DataFrameAnalyticsDest(randomAlphaOfLengthBetween(1, 10), randomAlphaOfLengthBetween(1, 10)); + return DataFrameAnalyticsDest.builder() + .setIndex(randomAlphaOfLengthBetween(1, 10)) + .setResultsField(randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 10)) + .build(); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java index 0898afb5b7781..eb254fd23de09 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java @@ -35,9 +35,10 @@ public class DataFrameAnalyticsSourceTests extends AbstractXContentTestCase { public static DataFrameAnalyticsSource randomSourceConfig() { - return new DataFrameAnalyticsSource( - randomAlphaOfLengthBetween(1, 10), - randomBoolean() ? null : randomQueryConfig()); + return DataFrameAnalyticsSource.builder() + .setIndex(randomAlphaOfLengthBetween(1, 10)) + .setQueryConfig(randomBoolean() ? null : randomQueryConfig()) + .build(); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java index 96a8f7126b08b..9eda15b04e813 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java @@ -23,17 +23,16 @@ import org.elasticsearch.test.AbstractXContentTestCase; import java.io.IOException; -import java.util.Map; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.is; public class OutlierDetectionTests extends AbstractXContentTestCase { public static OutlierDetection randomOutlierDetection() { - return new OutlierDetection( - randomBoolean() ? null : randomIntBetween(1, 20), - randomBoolean() ? 
null : randomFrom(OutlierDetection.Method.values())); + return OutlierDetection.builder() + .setNNeighbors(randomBoolean() ? null : randomIntBetween(1, 20)) + .setMethod(randomBoolean() ? null : randomFrom(OutlierDetection.Method.values())) + .build(); } @Override @@ -52,17 +51,18 @@ protected OutlierDetection createTestInstance() { } public void testGetParams_GivenDefaults() { - OutlierDetection outlierDetection = new OutlierDetection(); - assertThat(outlierDetection.getParams().isEmpty(), is(true)); + OutlierDetection outlierDetection = OutlierDetection.createDefault(); + assertNull(outlierDetection.getNNeighbors()); + assertNull(outlierDetection.getMethod()); } public void testGetParams_GivenExplicitValues() { - OutlierDetection outlierDetection = new OutlierDetection(42, OutlierDetection.Method.LDOF); - - Map params = outlierDetection.getParams(); - - assertThat(params.size(), equalTo(2)); - assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42)); - assertThat(params.get(OutlierDetection.METHOD.getPreferredName()), equalTo(OutlierDetection.Method.LDOF)); + OutlierDetection outlierDetection = + OutlierDetection.builder() + .setNNeighbors(42) + .setMethod(OutlierDetection.Method.LDOF) + .build(); + assertThat(outlierDetection.getNNeighbors(), equalTo(42)); + assertThat(outlierDetection.getMethod(), equalTo(OutlierDetection.Method.LDOF)); } } From dce6f505c56718d19f6da313a08ec01b9eb03f7c Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 8 May 2019 14:04:28 +0300 Subject: [PATCH 45/67] [ML] Import PageParams correctly --- .../client/ml/GetDataFrameAnalyticsStatsRequest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java index 044d62f229fe5..84bef6894213e 100644 --- 
a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java @@ -21,7 +21,7 @@ import org.elasticsearch.client.Validatable; import org.elasticsearch.client.ValidationException; -import org.elasticsearch.client.ml.job.util.PageParams; +import org.elasticsearch.client.core.PageParams; import org.elasticsearch.common.Nullable; import java.util.Arrays; From 0192c84f024c1f05a08ae6d9beea6a0e2a64151d Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 8 May 2019 21:25:34 +0300 Subject: [PATCH 46/67] [FEATURE][ML] Remove evaluation result abstraction (#41937) This removes the `EvaluationResult` interface and adds a list of metrics directly into the response object. This abstraction makes the HLRC code much more complicated and seems unnecessary. All evaluations should have metrics. Any additional metadata we might add shall be common and can go in the response object. 
--- .../xpack/core/XPackClientPlugin.java | 3 - .../ml/action/EvaluateDataFrameAction.java | 26 ++++++--- .../ml/dataframe/evaluation/Evaluation.java | 6 +- .../evaluation/EvaluationResult.java | 20 ------- .../MetricListEvaluationResult.java | 58 ------------------- .../MlEvaluationNamedXContentProvider.java | 4 -- .../BinarySoftClassification.java | 6 +- .../TransportEvaluateDataFrameAction.java | 10 ++-- 8 files changed, 29 insertions(+), 104 deletions(-) delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java delete mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 07fac1baa8b58..d2f51a2b11ca0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -146,8 +146,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MetricListEvaluationResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; @@ -443,7 +441,6 @@ public List getNamedWriteables() { // ML - Data frame evaluation new NamedWriteableRegistry.Entry(Evaluation.class, 
BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new), - new NamedWriteableRegistry.Entry(EvaluationResult.class, MetricListEvaluationResult.NAME, MetricListEvaluationResult::new), new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new), new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new), new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java index 0b51d097532d3..eec58428d55cd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java @@ -21,7 +21,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -152,38 +152,46 @@ static class RequestBuilder extends ActionRequestBuilder { public static class Response extends ActionResponse implements ToXContentObject { - private EvaluationResult result; + private String evaluationName; + private List metrics; public Response() { } - public Response(EvaluationResult result) { - this.result = Objects.requireNonNull(result); + public Response(String evaluationName, List metrics) { + this.evaluationName = Objects.requireNonNull(evaluationName); + this.metrics = 
Objects.requireNonNull(metrics); } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); - this.result = in.readNamedWriteable(EvaluationResult.class); + this.evaluationName = in.readString(); + this.metrics = in.readNamedWriteableList(EvaluationMetricResult.class); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeNamedWriteable(result); + out.writeString(evaluationName); + out.writeList(metrics); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(result.getEvaluationName(), result); + builder.startObject(evaluationName); + for (EvaluationMetricResult metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); builder.endObject(); return builder; } @Override public int hashCode() { - return Objects.hash(result); + return Objects.hash(evaluationName, metrics); } @Override @@ -195,7 +203,7 @@ public boolean equals(Object obj) { return false; } Response other = (Response) obj; - return Objects.equals(result, other.result); + return Objects.equals(evaluationName, other.evaluationName) && Objects.equals(metrics, other.metrics); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java index 0089d2e04e894..c01c19e33e865 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -11,6 +11,8 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.search.builder.SearchSourceBuilder; +import java.util.List; + /** * Defines an evaluation */ @@ -29,7 +31,7 @@ public interface Evaluation extends 
ToXContentObject, NamedWriteable { /** * Computes the evaluation result * @param searchResponse The search response required to compute the result - * @param listener A listener of the result + * @param listener A listener of the results */ - void evaluate(SearchResponse searchResponse, ActionListener listener); + void evaluate(SearchResponse searchResponse, ActionListener> listener); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java deleted file mode 100644 index 60b2701ee14a2..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationResult.java +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.dataframe.evaluation; - -import org.elasticsearch.common.io.stream.NamedWriteable; -import org.elasticsearch.common.xcontent.ToXContentObject; - -/** - * The result of an evaluation - */ -public interface EvaluationResult extends ToXContentObject, NamedWriteable { - - /** - * Returns the name of the evaluation - */ - String getEvaluationName(); -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java deleted file mode 100644 index cd32e8aaa4594..0000000000000 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MetricListEvaluationResult.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.core.ml.dataframe.evaluation; - -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public class MetricListEvaluationResult implements EvaluationResult { - - public static final String NAME = "metric_list_evaluation_result"; - - private final String evaluationName; - private final List metrics; - - public MetricListEvaluationResult(String evaluationName, List metrics) { - this.evaluationName = Objects.requireNonNull(evaluationName); - this.metrics = Objects.requireNonNull(metrics); - } - - public MetricListEvaluationResult(StreamInput in) throws IOException { - this.evaluationName = in.readString(); - this.metrics = in.readNamedWriteableList(EvaluationMetricResult.class); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public String getEvaluationName() { - return evaluationName; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(evaluationName); - out.writeList(metrics); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - for (EvaluationMetricResult metric : metrics) { - builder.field(metric.getName(), metric); - } - builder.endObject(); - return builder; - } -} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java index 
a4b00840d2b1b..f4a6dba88e3b1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -46,10 +46,6 @@ public List getNamedWriteables() { namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), BinarySoftClassification::new)); - // Evaluation Results - namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationResult.class, MetricListEvaluationResult.NAME, - MetricListEvaluationResult::new)); - // Evaluation Metrics namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index 27a922f086c53..732fcdfb44eae 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -22,8 +22,6 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationResult; -import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MetricListEvaluationResult; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -176,7 +174,7 @@ private QueryBuilder buildQuery() { } 
@Override - public void evaluate(SearchResponse searchResponse, ActionListener listener) { + public void evaluate(SearchResponse searchResponse, ActionListener> listener) { if (searchResponse.getHits().getTotalHits().value == 0) { listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, predictedProbabilityField)); @@ -189,7 +187,7 @@ public void evaluate(SearchResponse searchResponse, ActionListener { @@ -40,13 +42,13 @@ protected void doExecute(Task task, EvaluateDataFrameAction.Request request, SearchRequest searchRequest = new SearchRequest(request.getIndices()); searchRequest.source(evaluation.buildSearch()); - ActionListener resultListener = ActionListener.wrap( - result -> listener.onResponse(new EvaluateDataFrameAction.Response(result)), + ActionListener> resultsListener = ActionListener.wrap( + results -> listener.onResponse(new EvaluateDataFrameAction.Response(evaluation.getName(), results)), listener::onFailure ); client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap( - searchResponse -> threadPool.generic().execute(() -> evaluation.evaluate(searchResponse, resultListener)), + searchResponse -> threadPool.generic().execute(() -> evaluation.evaluate(searchResponse, resultsListener)), listener::onFailure )); } From 6aa3a979ca35947093d230e65d30e475ab49d871 Mon Sep 17 00:00:00 2001 From: Hendrik Muhs Date: Thu, 9 May 2019 14:49:49 +0200 Subject: [PATCH 47/67] [ML-DataFrame] timebased continuous dataframe (#41880) adds continuous dataframes based on timestamps in the source data for term groups --- .../core/dataframe/DataFrameMessages.java | 4 +- .../core/dataframe/transforms/SyncConfig.java | 4 +- .../dataframe/transforms/TimeSyncConfig.java | 6 +- .../pivot/DateHistogramGroupSource.java | 13 ++ .../pivot/HistogramGroupSource.java | 13 ++ .../transforms/pivot/PivotConfig.java | 6 +- .../transforms/pivot/SingleGroupSource.java | 6 + .../transforms/pivot/TermsGroupSource.java 
| 13 ++ .../DataFrameTransformsCheckpointService.java | 15 +- .../transforms/DataFrameIndexer.java | 142 +++++++++++++++++- ...FrameTransformPersistentTasksExecutor.java | 30 +++- .../transforms/DataFrameTransformTask.java | 70 ++++++++- .../dataframe/transforms/pivot/Pivot.java | 87 ++++++++++- .../transforms/DataFrameIndexerTests.java | 15 +- 14 files changed, 392 insertions(+), 32 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java index d31892692a5df..851042d1e67fe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java @@ -38,7 +38,9 @@ public class DataFrameMessages { public static final String FAILED_TO_PARSE_TRANSFORM_CONFIGURATION = "Failed to parse transform configuration for data frame transform [{0}]"; public static final String FAILED_TO_PARSE_TRANSFORM_STATISTICS_CONFIGURATION = - "Failed to parse transform statistics for data frame transform [{0}]"; + "Failed to parse transform statistics for data frame transform [{0}]"; + public static final String FAILED_TO_LOAD_TRANSFORM_CHECKPOINT = + "Failed to load data frame transform configuration for transform [{0}]"; public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_NO_TRANSFORM = "Data frame transform configuration must specify exactly 1 function"; public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY = diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java index d8008f12126e3..19ff79ea7e0ee 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java 
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java @@ -19,7 +19,7 @@ public interface SyncConfig extends ToXContentObject, NamedWriteable { */ boolean isValid(); - QueryBuilder getBoundaryQuery(DataFrameTransformCheckpoint checkpoint); + QueryBuilder getRangeQuery(DataFrameTransformCheckpoint newCheckpoint); - QueryBuilder getChangesQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint); + QueryBuilder getRangeQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java index 9a949ab21e8c3..0490394d90b26 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java @@ -136,12 +136,12 @@ public String getWriteableName() { } @Override - public QueryBuilder getBoundaryQuery(DataFrameTransformCheckpoint checkpoint) { - return new RangeQueryBuilder(field).lt(checkpoint.getTimeUpperBound()).format("epoch_millis"); + public QueryBuilder getRangeQuery(DataFrameTransformCheckpoint newCheckpoint) { + return new RangeQueryBuilder(field).lt(newCheckpoint.getTimeUpperBound()).format("epoch_millis"); } @Override - public QueryBuilder getChangesQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint) { + public QueryBuilder getRangeQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint) { return new RangeQueryBuilder(field).gte(oldCheckpoint.getTimeUpperBound()).lt(newCheckpoint.getTimeUpperBound()) .format("epoch_millis"); } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java index f4bf094235ae4..8cf79268f769a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java @@ -12,12 +12,14 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; import java.io.IOException; import java.time.ZoneId; import java.time.ZoneOffset; import java.util.Objects; +import java.util.Set; public class DateHistogramGroupSource extends SingleGroupSource { @@ -179,4 +181,15 @@ public boolean equals(Object other) { public int hashCode() { return Objects.hash(field, interval, dateHistogramInterval, timeZone, format); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + // no need for an extra range filter as this is already done by checkpoints + return null; + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java index 737590a0cc197..372f4ad99b608 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java @@ -11,9 +11,11 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -99,4 +101,15 @@ public boolean equals(Object other) { public int hashCode() { return Objects.hash(field, interval); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + // histograms are simple and cheap, so we skip this optimization + return null; + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java index c1c894e2971ae..c40e4bb23353d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java @@ -86,12 +86,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public void toCompositeAggXContent(XContentBuilder builder, Params params) throws IOException { + public void toCompositeAggXContent(XContentBuilder builder, boolean forChangeDetection) throws IOException { builder.startObject(); builder.field(CompositeAggregationBuilder.SOURCES_FIELD_NAME.getPreferredName()); builder.startArray(); for (Entry groupBy : groups.getGroups().entrySet()) { + // some group source do not implement change detection or 
not makes no sense, skip those + if (forChangeDetection && groupBy.getValue().supportsIncrementalBucketUpdate() == false) { + continue; + } builder.startObject(); builder.startObject(groupBy.getKey()); builder.field(groupBy.getValue().getType().value(), groupBy.getValue()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java index 0a4cf2579460e..ff1f9c3d54ac8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java @@ -14,10 +14,12 @@ import org.elasticsearch.common.xcontent.AbstractObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; import java.util.Locale; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -94,6 +96,10 @@ public void writeTo(StreamOutput out) throws IOException { public abstract Type getType(); + public abstract boolean supportsIncrementalBucketUpdate(); + + public abstract QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets); + public String getField() { return field; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java index d4585a611b367..891b160da0762 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java @@ -9,8 +9,11 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.TermsQueryBuilder; import java.io.IOException; +import java.util.Set; /* * A terms aggregation source for group_by @@ -47,4 +50,14 @@ public Type getType() { public static TermsGroupSource fromXContent(final XContentParser parser, boolean lenient) throws IOException { return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + return new TermsQueryBuilder(field, changedBuckets); + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return true; + } } diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointService.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointService.java index 6fc2e334f9255..fad9836b760d8 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointService.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/checkpoint/DataFrameTransformsCheckpointService.java @@ -20,6 +20,8 @@ import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointStats; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpointingInfo; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import 
org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.elasticsearch.xpack.dataframe.persistence.DataFrameTransformsConfigManager; import java.util.Arrays; @@ -84,8 +86,8 @@ public void getCheckpoint(DataFrameTransformConfig transformConfig, long checkpo ActionListener listener) { long timestamp = System.currentTimeMillis(); - // placeholder for time based synchronization - long timeUpperBound = 0; + // for time based synchronization + long timeUpperBound = getTimeStampForTimeBasedSynchronization(transformConfig.getSyncConfig(), timestamp); // 1st get index to see the indexes the user has access to GetIndexRequest getIndexRequest = new GetIndexRequest() @@ -205,6 +207,15 @@ public void getCheckpointStats( ); } + private long getTimeStampForTimeBasedSynchronization(SyncConfig syncConfig, long timestamp) { + if (syncConfig instanceof TimeSyncConfig) { + TimeSyncConfig timeSyncConfig = (TimeSyncConfig) syncConfig; + return timestamp - timeSyncConfig.getDelay().millis(); + } + + return 0L; + } + static Map extractIndexCheckPoints(ShardStats[] shards, Set userIndices) { Map> checkpointsByIndex = new TreeMap<>(); diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java index f2fc71da7f059..9234ce8ab5dc3 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java @@ -16,10 +16,15 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; import 
org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregation; +import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.dataframe.DataFrameField; import org.elasticsearch.xpack.core.dataframe.DataFrameMessages; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats; +import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformProgress; import org.elasticsearch.xpack.core.dataframe.utils.ExceptionsHelper; @@ -34,6 +39,7 @@ import java.util.Collections; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -55,6 +61,8 @@ public abstract class DataFrameIndexer extends AsyncTwoPhaseIndexer> changedBuckets; public DataFrameIndexer(Executor executor, DataFrameAuditor auditor, @@ -63,12 +71,14 @@ public DataFrameIndexer(Executor executor, AtomicReference initialState, Map initialPosition, DataFrameIndexerTransformStats jobStats, - DataFrameTransformProgress transformProgress) { + DataFrameTransformProgress transformProgress, + DataFrameTransformCheckpoint inProgressOrLastCheckpoint) { super(executor, initialState, initialPosition, jobStats); this.auditor = Objects.requireNonNull(auditor); this.transformConfig = ExceptionsHelper.requireNonNull(transformConfig, "transformConfig"); this.fieldMappings = ExceptionsHelper.requireNonNull(fieldMappings, "fieldMappings"); this.progress = transformProgress; + this.inProgressOrLastCheckpoint = inProgressOrLastCheckpoint; } protected abstract void failIndexer(String message); @@ -81,6 +91,10 @@ public DataFrameTransformConfig 
getConfig() { return transformConfig; } + public boolean isContinuous() { + return getConfig().getSyncConfig() != null; + } + public Map getFieldMappings() { return fieldMappings; } @@ -92,7 +106,7 @@ public DataFrameTransformProgress getProgress() { /** * Request a checkpoint */ - protected abstract void createCheckpoint(ActionListener listener); + protected abstract void createCheckpoint(ActionListener listener); @Override protected void onStart(long now, ActionListener listener) { @@ -106,7 +120,23 @@ protected void onStart(long now, ActionListener listener) { // if run for the 1st time, create checkpoint if (initialRun()) { - createCheckpoint(listener); + createCheckpoint(ActionListener.wrap(cp -> { + DataFrameTransformCheckpoint oldCheckpoint = inProgressOrLastCheckpoint; + + if (oldCheckpoint.isEmpty()) { + // this is the 1st run, accept the new in progress checkpoint and go on + inProgressOrLastCheckpoint = cp; + listener.onResponse(null); + } else { + logger.debug ("Getting changes from {} to {}", oldCheckpoint.getTimeUpperBound(), cp.getTimeUpperBound()); + + getChangedBuckets(oldCheckpoint, cp, ActionListener.wrap(changedBuckets -> { + inProgressOrLastCheckpoint = cp; + this.changedBuckets = changedBuckets; + listener.onResponse(null); + }, listener::onFailure)); + } + }, listener::onFailure)); } else { listener.onResponse(null); } @@ -123,6 +153,8 @@ protected boolean initialRun() { protected void onFinish(ActionListener listener) { // reset the page size, so we do not memorize a low page size forever, the pagesize will be re-calculated on start pageSize = 0; + // reset the changed bucket to free memory + changedBuckets = null; } @Override @@ -185,7 +217,38 @@ private Stream processBucketsToIndexRequests(CompositeAggregation @Override protected SearchRequest buildSearchRequest() { - return pivot.buildSearchRequest(getConfig().getSource(), getPosition(), pageSize); + SearchRequest searchRequest = new SearchRequest(getConfig().getSource().getIndex()); + 
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.aggregation(pivot.buildAggregation(getPosition(), pageSize)); + sourceBuilder.size(0); + + QueryBuilder pivotQueryBuilder = getConfig().getSource().getQueryConfig().getQuery(); + + DataFrameTransformConfig config = getConfig(); + if (config.getSyncConfig() != null) { + if (inProgressOrLastCheckpoint == null) { + throw new RuntimeException("in progress checkpoint not found"); + } + + BoolQueryBuilder filteredQuery = new BoolQueryBuilder(). + filter(pivotQueryBuilder). + filter(config.getSyncConfig().getRangeQuery(inProgressOrLastCheckpoint)); + + if (changedBuckets != null && changedBuckets.isEmpty() == false) { + QueryBuilder pivotFilter = pivot.filterBuckets(changedBuckets); + if (pivotFilter != null) { + filteredQuery.filter(pivotFilter); + } + } + + logger.trace("running filtered query: {}", filteredQuery); + sourceBuilder.query(filteredQuery); + } else { + sourceBuilder.query(pivotQueryBuilder); + } + + searchRequest.source(sourceBuilder); + return searchRequest; } /** @@ -226,6 +289,75 @@ protected boolean handleCircuitBreakingException(Exception e) { return true; } + private void getChangedBuckets(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint, + ActionListener>> listener) { + + // initialize the map of changed buckets, the map might be empty if source do not require/implement + // changed bucket detection + Map> keys = pivot.initialIncrementalBucketUpdateMap(); + if (keys.isEmpty()) { + logger.trace("This data frame does not implement changed bucket detection, returning"); + listener.onResponse(null); + return; + } + + SearchRequest searchRequest = new SearchRequest(getConfig().getSource().getIndex()); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + + // we do not need the sub-aggs + CompositeAggregationBuilder changesAgg = pivot.buildIncrementalBucketUpdateAggregation(pageSize); + sourceBuilder.aggregation(changesAgg); + 
sourceBuilder.size(0); + + QueryBuilder pivotQueryBuilder = getConfig().getSource().getQueryConfig().getQuery(); + + DataFrameTransformConfig config = getConfig(); + if (config.getSyncConfig() != null) { + BoolQueryBuilder filteredQuery = new BoolQueryBuilder(). + filter(pivotQueryBuilder). + filter(config.getSyncConfig().getRangeQuery(oldCheckpoint, newCheckpoint)); + + logger.trace("Gathering changes using query {}", filteredQuery); + sourceBuilder.query(filteredQuery); + } else { + logger.trace("No sync configured"); + listener.onResponse(null); + return; + } + + searchRequest.source(sourceBuilder); + searchRequest.allowPartialSearchResults(false); + + collectChangedBuckets(searchRequest, changesAgg, keys, ActionListener.wrap(listener::onResponse, e -> { + // fall back if bucket collection failed + logger.error("Failed to retrieve changed buckets, fall back to complete retrieval", e); + listener.onResponse(null); + })); + } + + void collectChangedBuckets(SearchRequest searchRequest, CompositeAggregationBuilder changesAgg, Map> keys, + ActionListener>> finalListener) { + + // re-using the existing search hook + doNextSearch(searchRequest, ActionListener.wrap(searchResponse -> { + final CompositeAggregation agg = searchResponse.getAggregations().get(COMPOSITE_AGGREGATION_NAME); + + agg.getBuckets().stream().forEach(bucket -> { + bucket.getKey().forEach((k, v) -> { + keys.get(k).add(v.toString()); + }); + }); + + if (agg.getBuckets().isEmpty()) { + finalListener.onResponse(keys); + } else { + // adjust the after key + changesAgg.aggregateAfter(agg.afterKey()); + collectChangedBuckets(searchRequest, changesAgg, keys, finalListener); + } + }, finalListener::onFailure)); + } + /** * Inspect exception for circuit breaking exception and return the first one it can find. 
* @@ -251,4 +383,6 @@ private static CircuitBreakingException getCircuitBreakingException(Exception e) return null; } + + protected abstract boolean sourceHasChanged(); } diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformPersistentTasksExecutor.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformPersistentTasksExecutor.java index d0f15197c3cca..1809ae50c8cd9 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformPersistentTasksExecutor.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformPersistentTasksExecutor.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.dataframe.action.StartDataFrameTransformTaskAction; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransform; +import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformState; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformTaskState; @@ -130,7 +131,7 @@ protected void nodeOperation(AllocatedPersistentTask task, @Nullable DataFrameTr failure -> logger.error("Failed to start task ["+ transformId +"] in node operation", failure) ); - // <3> Set the previous stats (if they exist), initialize the indexer, start the task (If it is STOPPED) + // <4> Set the previous stats (if they exist), initialize the indexer, start the task (If it is STOPPED) // Since we don't create the task until `_start` is called, if we see that the task state is stopped, attempt to start // Schedule execution regardless ActionListener transformStatsActionListener = 
ActionListener.wrap( @@ -149,11 +150,34 @@ protected void nodeOperation(AllocatedPersistentTask task, @Nullable DataFrameTr } ); + // <3> set the in progress checkpoint for the indexer, get the in progress checkpoint + ActionListener getTransformCheckpointListener = ActionListener.wrap( + cp -> { + indexerBuilder.setInProgressOrLastCheckpoint(cp); + transformsConfigManager.getTransformStats(transformId, transformStatsActionListener); + }, + error -> { + String msg = DataFrameMessages.getMessage(DataFrameMessages.FAILED_TO_LOAD_TRANSFORM_CHECKPOINT, transformId); + logger.error(msg, error); + markAsFailed(buildTask, msg); + } + ); + // <2> set fieldmappings for the indexer, get the previous stats (if they exist) ActionListener> getFieldMappingsListener = ActionListener.wrap( fieldMappings -> { indexerBuilder.setFieldMappings(fieldMappings); - transformsConfigManager.getTransformStats(transformId, transformStatsActionListener); + + long inProgressCheckpoint = transformState == null ? 0L : + Math.max(transformState.getCheckpoint(), transformState.getInProgressCheckpoint()); + + logger.debug("Restore in progress or last checkpoint: {}", inProgressCheckpoint); + + if (inProgressCheckpoint == 0) { + getTransformCheckpointListener.onResponse(DataFrameTransformCheckpoint.EMPTY); + } else { + transformsConfigManager.getTransformCheckpoint(transformId, inProgressCheckpoint, getTransformCheckpointListener); + } }, error -> { String msg = DataFrameMessages.getMessage(DataFrameMessages.DATA_FRAME_UNABLE_TO_GATHER_FIELD_MAPPINGS, @@ -238,7 +262,7 @@ private void scheduleAndStartTask(DataFrameTransformTask buildTask, static SchedulerEngine.Schedule next() { return (startTime, now) -> { - return now + 1000; // to be fixed, hardcode something + return now + 10000; // to be fixed, hardcode something }; } diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java 
b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java index 2020300a0cf77..6a36c42655375 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.LatchedActionListener; import org.elasticsearch.action.bulk.BulkAction; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkResponse; @@ -30,6 +31,7 @@ import org.elasticsearch.xpack.core.dataframe.action.StopDataFrameTransformAction; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransform; +import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformProgress; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformState; @@ -44,6 +46,8 @@ import java.util.Arrays; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; @@ -271,10 +275,17 @@ public synchronized void triggered(Event event) { logger.warn("Data frame task [{}] triggered with an unintialized indexer", getTransformId()); return; } - // for now no rerun, so only trigger if checkpoint == 0 - if (currentCheckpoint.get() == 0 && event.getJobName().equals(SCHEDULE_NAME + "_" + transform.getId())) { - 
logger.debug("Data frame indexer [{}] schedule has triggered, state: [{}]", event.getJobName(), getIndexer().getState()); - getIndexer().maybeTriggerAsyncJob(System.currentTimeMillis()); + if (event.getJobName().equals(SCHEDULE_NAME + "_" + transform.getId())) { + logger.info("Data frame indexer [{}] schedule has triggered, state: [{}]", event.getJobName(), getIndexer().getState()); + + // if it runs for the 1st time we just do it, if not we check for changes + if (currentCheckpoint.get() == 0 ) { + logger.debug("Trigger initial run"); + getIndexer().maybeTriggerAsyncJob(System.currentTimeMillis()); + } else if (getIndexer().isContinuous() && getIndexer().sourceHasChanged()) { + logger.debug("Source has changed, triggering new indexer run"); + getIndexer().maybeTriggerAsyncJob(System.currentTimeMillis()); + } } } @@ -363,6 +374,7 @@ static class ClientDataFrameIndexerBuilder { private IndexerState indexerState = IndexerState.STOPPED; private Map initialPosition; private DataFrameTransformProgress progress; + private DataFrameTransformCheckpoint inProgressOrLastCheckpoint; ClientDataFrameIndexer build(DataFrameTransformTask parentTask) { return new ClientDataFrameIndexer(this.transformId, @@ -376,6 +388,7 @@ ClientDataFrameIndexer build(DataFrameTransformTask parentTask) { this.transformConfig, this.fieldMappings, this.progress, + this.inProgressOrLastCheckpoint, parentTask); } @@ -437,6 +450,11 @@ ClientDataFrameIndexerBuilder setProgress(DataFrameTransformProgress progress) { this.progress = progress; return this; } + + ClientDataFrameIndexerBuilder setInProgressOrLastCheckpoint(DataFrameTransformCheckpoint inProgressOrLastCheckpoint) { + this.inProgressOrLastCheckpoint = inProgressOrLastCheckpoint; + return this; + } } static class ClientDataFrameIndexer extends DataFrameIndexer { @@ -461,6 +479,7 @@ static class ClientDataFrameIndexer extends DataFrameIndexer { DataFrameTransformConfig transformConfig, Map fieldMappings, DataFrameTransformProgress 
transformProgress, + DataFrameTransformCheckpoint inProgressOrLastCheckpoint, DataFrameTransformTask parentTask) { super(ExceptionsHelper.requireNonNull(parentTask, "parentTask") .threadPool @@ -471,7 +490,8 @@ static class ClientDataFrameIndexer extends DataFrameIndexer { ExceptionsHelper.requireNonNull(initialState, "initialState"), initialPosition, initialStats == null ? new DataFrameIndexerTransformStats(transformId) : initialStats, - transformProgress); + transformProgress, + inProgressOrLastCheckpoint); this.transformId = ExceptionsHelper.requireNonNull(transformId, "transformId"); this.transformsConfigManager = ExceptionsHelper.requireNonNull(transformsConfigManager, "transformsConfigManager"); this.transformsCheckpointService = ExceptionsHelper.requireNonNull(transformsCheckpointService, @@ -620,13 +640,13 @@ protected void onAbort() { } @Override - protected void createCheckpoint(ActionListener listener) { + protected void createCheckpoint(ActionListener listener) { transformsCheckpointService.getCheckpoint(transformConfig, transformTask.currentCheckpoint.get() + 1, ActionListener.wrap( checkpoint -> transformsConfigManager.putTransformCheckpoint(checkpoint, ActionListener.wrap( - putCheckPointResponse -> listener.onResponse(null), + putCheckPointResponse -> listener.onResponse(checkpoint), createCheckpointException -> listener.onFailure(new RuntimeException("Failed to create checkpoint", createCheckpointException)) )), @@ -635,6 +655,42 @@ protected void createCheckpoint(ActionListener listener) { )); } + @Override + public boolean sourceHasChanged() { + if (getState() == IndexerState.INDEXING) { + logger.trace("Indexer is still running, ignore"); + return false; + } + + CountDownLatch latch = new CountDownLatch(1); + + SetOnce changed = new SetOnce<>(); + transformsCheckpointService.getCheckpoint(transformConfig, new LatchedActionListener<>(ActionListener.wrap( + cp -> { + long behind = DataFrameTransformCheckpoint.getBehind(inProgressOrLastCheckpoint, 
cp); + if (behind > 0) { + logger.debug("Detected changes, dest is {} operations behind the source", behind); + changed.set(true); + } else { + changed.set(false); + } + }, e -> { + changed.set(false); + logger.error("failure in update check", e); + }), latch)); + + try { + if (latch.await(5, TimeUnit.SECONDS)) { + logger.trace("Change detected:" + changed.get()); + return changed.get(); + } + } catch (InterruptedException e) { + logger.error("Failed to check for update", e); + } + + return false; + } + private boolean isIrrecoverableFailure(Exception e) { return e instanceof IndexNotFoundException; } diff --git a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Pivot.java b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Pivot.java index 0e5231442d18b..273d32f47c62a 100644 --- a/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Pivot.java +++ b/x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/pivot/Pivot.java @@ -6,6 +6,8 @@ package org.elasticsearch.xpack.dataframe.transforms.pivot; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; @@ -13,9 +15,9 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.aggregations.AggregationBuilder; @@ 
-28,10 +30,15 @@ import org.elasticsearch.xpack.core.dataframe.transforms.SourceConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.GroupConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.pivot.SingleGroupSource; import java.io.IOException; import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; import java.util.stream.Stream; import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; @@ -41,6 +48,7 @@ public class Pivot { public static final int TEST_QUERY_PAGE_SIZE = 50; private static final String COMPOSITE_AGGREGATION_NAME = "_data_frame"; + private static final Logger logger = LogManager.getLogger(Pivot.class); private final PivotConfig config; @@ -105,6 +113,26 @@ public AggregationBuilder buildAggregation(Map position, int pag return cachedCompositeAggregation; } + public CompositeAggregationBuilder buildIncrementalBucketUpdateAggregation(int pageSize) { + + CompositeAggregationBuilder compositeAgg = createCompositeAggregationSources(config, true); + compositeAgg.size(pageSize); + + return compositeAgg; + } + + public Map> initialIncrementalBucketUpdateMap() { + + Map> changedBuckets = new HashMap<>(); + for(Entry entry: config.getGroupConfig().getGroups().entrySet()) { + if (entry.getValue().supportsIncrementalBucketUpdate()) { + changedBuckets.put(entry.getKey(), new HashSet<>()); + } + } + + return changedBuckets; + } + public Stream> extractResults(CompositeAggregation agg, Map fieldTypeMap, DataFrameIndexerTransformStats dataFrameIndexerTransformStats) { @@ -139,17 +167,66 @@ private void runTestQuery(Client client, SourceConfig sourceConfig, final Action })); } + public QueryBuilder filterBuckets(Map> changedBuckets) { + + if (changedBuckets == null || changedBuckets.isEmpty()) { + return null; + } + + if 
(config.getGroupConfig().getGroups().size() == 1) { + Entry entry = config.getGroupConfig().getGroups().entrySet().iterator().next(); + // it should not be possible to get into this code path + assert (entry.getValue().supportsIncrementalBucketUpdate()); + + logger.trace("filter by bucket: " + entry.getKey() + "/" + entry.getValue().getField()); + if (changedBuckets.containsKey(entry.getKey())) { + return entry.getValue().getIncrementalBucketUpdateFilterQuery(changedBuckets.get(entry.getKey())); + } else { + // should never happen + throw new RuntimeException("Could not find bucket value for key " + entry.getKey()); + } + } + + // else: more than 1 group by, need to nest it + BoolQueryBuilder filteredQuery = new BoolQueryBuilder(); + for (Entry entry : config.getGroupConfig().getGroups().entrySet()) { + if (entry.getValue().supportsIncrementalBucketUpdate() == false) { + continue; + } + + if (changedBuckets.containsKey(entry.getKey())) { + QueryBuilder sourceQueryFilter = entry.getValue().getIncrementalBucketUpdateFilterQuery(changedBuckets.get(entry.getKey())); + // the source might not define an filter optimization + if (sourceQueryFilter != null) { + filteredQuery.filter(sourceQueryFilter); + } + } else { + // should never happen + throw new RuntimeException("Could not find bucket value for key " + entry.getKey()); + } + + } + + return filteredQuery; + } + private static CompositeAggregationBuilder createCompositeAggregation(PivotConfig config) { + final CompositeAggregationBuilder compositeAggregation = createCompositeAggregationSources(config, false); + + config.getAggregationConfig().getAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg)); + config.getAggregationConfig().getPipelineAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg)); + + return compositeAggregation; + } + + private static CompositeAggregationBuilder createCompositeAggregationSources(PivotConfig config, boolean forChangeDetection) { 
CompositeAggregationBuilder compositeAggregation; try (XContentBuilder builder = jsonBuilder()) { - // write configuration for composite aggs into builder - config.toCompositeAggXContent(builder, ToXContentObject.EMPTY_PARAMS); + config.toCompositeAggXContent(builder, forChangeDetection); XContentParser parser = builder.generator().contentType().xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, BytesReference.bytes(builder).streamInput()); compositeAggregation = CompositeAggregationBuilder.parse(COMPOSITE_AGGREGATION_NAME, parser); - config.getAggregationConfig().getAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg)); - config.getAggregationConfig().getPipelineAggregatorFactories().forEach(agg -> compositeAggregation.subAggregation(agg)); } catch (IOException e) { throw new RuntimeException(DataFrameMessages.DATA_FRAME_TRANSFORM_PIVOT_FAILED_TO_CREATE_COMPOSITE_AGGREGATION, e); } diff --git a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexerTests.java b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexerTests.java index f3f3255f07a6d..4bfe5b85b20b4 100644 --- a/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexerTests.java +++ b/x-pack/plugin/data-frame/src/test/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexerTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameIndexerTransformStats; +import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformCheckpoint; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfigTests; import org.elasticsearch.xpack.core.indexing.IndexerState; @@ -68,7 
+69,7 @@ class MockedDataFrameIndexer extends DataFrameIndexer { Function bulkFunction, Consumer failureConsumer) { super(executor, auditor, transformConfig, fieldMappings, initialState, initialPosition, jobStats, - /* DataFrameTransformProgress */ null); + /* DataFrameTransformProgress */ null, DataFrameTransformCheckpoint.EMPTY); this.searchFunction = searchFunction; this.bulkFunction = bulkFunction; this.failureConsumer = failureConsumer; @@ -79,8 +80,8 @@ public CountDownLatch newLatch(int count) { } @Override - protected void createCheckpoint(ActionListener listener) { - listener.onResponse(null); + protected void createCheckpoint(ActionListener listener) { + listener.onResponse(DataFrameTransformCheckpoint.EMPTY); } @Override @@ -158,6 +159,11 @@ protected void failIndexer(String message) { fail("failIndexer should not be called, received error: " + message); } + @Override + protected boolean sourceHasChanged() { + return false; + } + } @Before @@ -180,7 +186,7 @@ public void testPageSizeAdapt() throws InterruptedException { Function bulkFunction = bulkRequest -> new BulkResponse(new BulkItemResponse[0], 100); Consumer failureConsumer = e -> { - fail("expected circuit breaker exception to be handled"); + fail("expected circuit breaker exception to be handled, got " + e); }; final ExecutorService executor = Executors.newFixedThreadPool(1); @@ -218,4 +224,5 @@ public void testPageSizeAdapt() throws InterruptedException { executor.shutdownNow(); } } + } From e07b78cca69bdb81ae7c93fdb6055bcf64904386 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 24 May 2019 14:32:09 +0300 Subject: [PATCH 48/67] [ML] QueryProvider.fromStream should be public --- .../org/elasticsearch/xpack/core/ml/utils/QueryProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java 
index f470f2e26117a..3fe0ba70331a0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java @@ -68,7 +68,7 @@ public static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOE null); } - static QueryProvider fromStream(StreamInput in) throws IOException { + public static QueryProvider fromStream(StreamInput in) throws IOException { return new QueryProvider(in.readMap(), in.readOptionalNamedWriteable(QueryBuilder.class), in.readException()); } From 6d05f8f754c482bc48bfb110cacc544c54345552 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 27 May 2019 20:38:01 +0300 Subject: [PATCH 49/67] =?UTF-8?q?[FEATURE][ML]=20Allow=20configuring=20out?= =?UTF-8?q?lier=20score=20threshold=20for=20feature=20i=E2=80=A6=20(#42239?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../client/ml/dataframe/OutlierDetection.java | 32 ++++++++++++++----- .../ml/dataframe/OutlierDetectionTests.java | 5 +++ .../dataframe/analyses/OutlierDetection.java | 32 ++++++++++++++++--- .../persistence/ElasticsearchMappings.java | 3 ++ .../ml/job/results/ReservedFieldNames.java | 1 + .../analyses/OutlierDetectionTests.java | 10 ++++-- .../test/ml/data_frame_analytics_crud.yml | 30 +++++++++++++++++ 7 files changed, 97 insertions(+), 16 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java index bb0ecff6865ed..946c01ac5c835 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java @@ -47,6 +47,8 @@ public static Builder builder() { public static final ParseField 
NAME = new ParseField("outlier_detection"); static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); static final ParseField METHOD = new ParseField("method"); + public static final ParseField MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE = + new ParseField("minimum_score_to_write_feature_influence"); private static ObjectParser PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Builder::new); @@ -58,23 +60,23 @@ public static Builder builder() { } throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); }, METHOD, ObjectParser.ValueType.STRING); + PARSER.declareDouble(Builder::setMinScoreToWriteFeatureInfluence, MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE); } private final Integer nNeighbors; private final Method method; + private final Double minScoreToWriteFeatureInfluence; /** * Constructs the outlier detection configuration * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. * @param method The method. Leave unspecified for a dynamic mixture of methods. + * @param minScoreToWriteFeatureInfluence The min outlier score required to calculate feature influence. Defaults to 0.1. 
*/ - private OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method) { - if (nNeighbors != null && nNeighbors <= 0) { - throw new IllegalArgumentException("[" + N_NEIGHBORS.getPreferredName() + "] must be a positive integer"); - } - + private OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method, @Nullable Double minScoreToWriteFeatureInfluence) { this.nNeighbors = nNeighbors; this.method = method; + this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence; } @Override @@ -90,6 +92,10 @@ public Method getMethod() { return method; } + public Double getMinScoreToWriteFeatureInfluence() { + return minScoreToWriteFeatureInfluence; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -99,6 +105,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (method != null) { builder.field(METHOD.getPreferredName(), method); } + if (minScoreToWriteFeatureInfluence != null) { + builder.field(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } builder.endObject(); return builder; } @@ -110,12 +119,13 @@ public boolean equals(Object o) { OutlierDetection other = (OutlierDetection) o; return Objects.equals(nNeighbors, other.nNeighbors) - && Objects.equals(method, other.method); + && Objects.equals(method, other.method) + && Objects.equals(minScoreToWriteFeatureInfluence, other.minScoreToWriteFeatureInfluence); } @Override public int hashCode() { - return Objects.hash(nNeighbors, method); + return Objects.hash(nNeighbors, method, minScoreToWriteFeatureInfluence); } @Override @@ -140,6 +150,7 @@ public static class Builder { private Integer nNeighbors; private Method method; + private Double minScoreToWriteFeatureInfluence; private Builder() {} @@ -153,8 +164,13 @@ public Builder setMethod(Method method) { return this; } + public Builder 
setMinScoreToWriteFeatureInfluence(Double minScoreToWriteFeatureInfluence) { + this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence; + return this; + } + public OutlierDetection build() { - return new OutlierDetection(nNeighbors, method); + return new OutlierDetection(nNeighbors, method, minScoreToWriteFeatureInfluence); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java index 9eda15b04e813..de110d92fdee1 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java @@ -24,6 +24,7 @@ import java.io.IOException; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; public class OutlierDetectionTests extends AbstractXContentTestCase { @@ -32,6 +33,7 @@ public static OutlierDetection randomOutlierDetection() { return OutlierDetection.builder() .setNNeighbors(randomBoolean() ? null : randomIntBetween(1, 20)) .setMethod(randomBoolean() ? null : randomFrom(OutlierDetection.Method.values())) + .setMinScoreToWriteFeatureInfluence(randomBoolean() ? 
null : randomDoubleBetween(0.0, 1.0, true)) .build(); } @@ -54,6 +56,7 @@ public void testGetParams_GivenDefaults() { OutlierDetection outlierDetection = OutlierDetection.createDefault(); assertNull(outlierDetection.getNNeighbors()); assertNull(outlierDetection.getMethod()); + assertNull(outlierDetection.getMinScoreToWriteFeatureInfluence()); } public void testGetParams_GivenExplicitValues() { @@ -61,8 +64,10 @@ public void testGetParams_GivenExplicitValues() { OutlierDetection.builder() .setNNeighbors(42) .setMethod(OutlierDetection.Method.LDOF) + .setMinScoreToWriteFeatureInfluence(0.5) .build(); assertThat(outlierDetection.getNNeighbors(), equalTo(42)); assertThat(outlierDetection.getMethod(), equalTo(OutlierDetection.Method.LDOF)); + assertThat(outlierDetection.getMinScoreToWriteFeatureInfluence(), closeTo(0.5, 1E-9)); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 5cd00fa979550..91eb02b7bcdfe 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -27,13 +27,15 @@ public class OutlierDetection implements DataFrameAnalysis { public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); public static final ParseField METHOD = new ParseField("method"); + public static final ParseField MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE = + new ParseField("minimum_score_to_write_feature_influence"); private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser parser = new 
ConstructingObjectParser<>(NAME.getPreferredName(), lenient, - a -> new OutlierDetection((Integer) a[0], (Method) a[1])); + a -> new OutlierDetection((Integer) a[0], (Method) a[1], (Double) a[2])); parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), N_NEIGHBORS); parser.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { @@ -41,6 +43,7 @@ private static ConstructingObjectParser createParser(boo } throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); }, METHOD, ObjectParser.ValueType.STRING); + parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE); return parser; } @@ -50,31 +53,40 @@ public static OutlierDetection fromXContent(XContentParser parser, boolean ignor private final Integer nNeighbors; private final Method method; + private final Double minScoreToWriteFeatureInfluence; /** * Constructs the outlier detection configuration * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. * @param method The method. Leave unspecified for a dynamic mixture of methods. + * @param minScoreToWriteFeatureInfluence The min outlier score required to calculate feature influence. Defaults to 0.1. 
*/ - public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method) { + public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method, @Nullable Double minScoreToWriteFeatureInfluence) { if (nNeighbors != null && nNeighbors <= 0) { throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", N_NEIGHBORS.getPreferredName()); } + if (minScoreToWriteFeatureInfluence != null && (minScoreToWriteFeatureInfluence < 0.0 || minScoreToWriteFeatureInfluence > 1.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be in [0, 1]", + MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName()); + } + this.nNeighbors = nNeighbors; this.method = method; + this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence; } /** * Constructs the default outlier detection configuration */ public OutlierDetection() { - this(null, null); + this(null, null, null); } public OutlierDetection(StreamInput in) throws IOException { nNeighbors = in.readOptionalVInt(); method = in.readBoolean() ? 
in.readEnum(Method.class) : null; + minScoreToWriteFeatureInfluence = in.readOptionalDouble(); } @Override @@ -92,6 +104,8 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + + out.writeOptionalDouble(minScoreToWriteFeatureInfluence); } @Override @@ -103,6 +117,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (method != null) { builder.field(METHOD.getPreferredName(), method); } + if (minScoreToWriteFeatureInfluence != null) { + builder.field(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } builder.endObject(); return builder; } @@ -112,12 +129,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; OutlierDetection that = (OutlierDetection) o; - return Objects.equals(nNeighbors, that.nNeighbors) && Objects.equals(method, that.method); + return Objects.equals(nNeighbors, that.nNeighbors) + && Objects.equals(method, that.method) + && Objects.equals(minScoreToWriteFeatureInfluence, that.minScoreToWriteFeatureInfluence); } @Override public int hashCode() { - return Objects.hash(nNeighbors, method); + return Objects.hash(nNeighbors, method, minScoreToWriteFeatureInfluence); } @Override @@ -129,6 +148,9 @@ public Map getParams() { if (method != null) { params.put(METHOD.getPreferredName(), method); } + if (minScoreToWriteFeatureInfluence != null) { + params.put(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } return params; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index ac0c4a420448b..82a877f6f7445 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -427,6 +427,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(OutlierDetection.METHOD.getPreferredName()) .field(TYPE, KEYWORD) .endObject() + .startObject(OutlierDetection.MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() .endObject() .endObject() .endObject() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index f727f637b972b..6c60acfbe39ff 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -277,6 +277,7 @@ public final class ReservedFieldNames { OutlierDetection.NAME.getPreferredName(), OutlierDetection.N_NEIGHBORS.getPreferredName(), OutlierDetection.METHOD.getPreferredName(), + OutlierDetection.MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java index 0e3d826593258..d7a3269597101 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java @@ -12,6 +12,7 @@ import java.io.IOException; import java.util.Map; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -30,7 +31,8 @@ protected 
OutlierDetection createTestInstance() { public static OutlierDetection createRandom() { Integer numberNeighbors = randomBoolean() ? null : randomIntBetween(1, 20); OutlierDetection.Method method = randomBoolean() ? null : randomFrom(OutlierDetection.Method.values()); - return new OutlierDetection(numberNeighbors, method); + Double minScoreToWriteFeatureInfluence = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, true); + return new OutlierDetection(numberNeighbors, method, minScoreToWriteFeatureInfluence); } @Override @@ -44,12 +46,14 @@ public void testGetParams_GivenDefaults() { } public void testGetParams_GivenExplicitValues() { - OutlierDetection outlierDetection = new OutlierDetection(42, OutlierDetection.Method.LDOF); + OutlierDetection outlierDetection = new OutlierDetection(42, OutlierDetection.Method.LDOF, 0.42); Map params = outlierDetection.getParams(); - assertThat(params.size(), equalTo(2)); + assertThat(params.size(), equalTo(3)); assertThat(params.get(OutlierDetection.N_NEIGHBORS.getPreferredName()), equalTo(42)); assertThat(params.get(OutlierDetection.METHOD.getPreferredName()), equalTo(OutlierDetection.Method.LDOF)); + assertThat((Double) params.get(OutlierDetection.MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName()), + is(closeTo(0.42, 1E-9))); } } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 1cf450b26d669..5dc265a74da22 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -91,6 +91,36 @@ - match: { dest.index: "index-dest" } - match: { analysis: {"outlier_detection":{}} } +--- +"Test put valid config with custom outlier detection": + + - do: + ml.put_data_frame_analytics: + id: "custom-outlier-detection" + body: > + { + "source": { + "index": 
"index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": { + "outlier_detection":{ + "n_neighbors": 5, + "method": "lof", + "minimum_score_to_write_feature_influence": 0.0 + } + } + } + - match: { id: "custom-outlier-detection" } + - match: { source.index: "index-source" } + - match: { source.query: {"match_all" : {} } } + - match: { dest.index: "index-dest" } + - match: { analysis.outlier_detection.n_neighbors: 5 } + - match: { analysis.outlier_detection.method: "lof" } + - match: { analysis.outlier_detection.minimum_score_to_write_feature_influence: 0.0 } + --- "Test put config with inconsistent body/param ids": From eed4319ce8b9ba553546696fe35de0ebc7830cb6 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 29 May 2019 19:36:14 +0300 Subject: [PATCH 50/67] [ML] Fix compilation after merging master --- .../xpack/ml/action/TransportOpenJobActionTests.java | 3 +++ .../org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java index 6a20a19fab8b1..cc9a0ba0181ad 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java @@ -48,6 +48,9 @@ import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; import org.elasticsearch.xpack.core.ml.notifications.AuditorField; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import java.util.ArrayList; import java.util.Collections; diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java index 9395f59d15601..f26dd3f81f6de 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java @@ -445,7 +445,7 @@ public void testSelectLeastLoadedMlNode_noNodesMatchingModelSnapshotMinVersion() Job job = BaseMlIntegTestCase.createFareQuoteJob("job_with_incompatible_model_snapshot") .setModelSnapshotId("incompatible_snapshot") - .setModelSnapshotMinVersion(Version.V_6_3_0) + .setModelSnapshotMinVersion(Version.fromString("6.3.0")) .build(new Date()); cs.nodes(nodes); metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); From f27886a841120ce4ccef5844710233a83607b20c Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 30 May 2019 17:33:43 +0300 Subject: [PATCH 51/67] [ML] Remove argument from request converter params constructor --- .../java/org/elasticsearch/client/MLRequestConverters.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 50bb9f55d71e8..e88658a5fccd5 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -602,7 +602,7 @@ static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { .addPathPart(Strings.collectionToCommaDelimitedString(getRequest.getIds())) .build(); Request request = new Request(HttpGet.METHOD_NAME, endpoint); - RequestConverters.Params params = new RequestConverters.Params(request); + RequestConverters.Params params = new 
RequestConverters.Params(); if (getRequest.getPageParams() != null) { PageParams pageParams = getRequest.getPageParams(); if (pageParams.getFrom() != null) { @@ -622,7 +622,7 @@ static Request getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest getS .addPathPartAsIs("_stats") .build(); Request request = new Request(HttpGet.METHOD_NAME, endpoint); - RequestConverters.Params params = new RequestConverters.Params(request); + RequestConverters.Params params = new RequestConverters.Params(); if (getStatsRequest.getPageParams() != null) { PageParams pageParams = getStatsRequest.getPageParams(); if (pageParams.getFrom() != null) { @@ -642,7 +642,7 @@ static Request startDataFrameAnalytics(StartDataFrameAnalyticsRequest startReque .addPathPartAsIs("_start") .build(); Request request = new Request(HttpPost.METHOD_NAME, endpoint); - RequestConverters.Params params = new RequestConverters.Params(request); + RequestConverters.Params params = new RequestConverters.Params(); if (startRequest.getTimeout() != null) { params.withTimeout(startRequest.getTimeout()); } From 62fdff9a0fbede4df30b4cb9ce3c601bf48c8b79 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 30 May 2019 18:05:50 +0300 Subject: [PATCH 52/67] [ML] And add those params the right way --- .../java/org/elasticsearch/client/MLRequestConverters.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index e88658a5fccd5..be6bc6f0d403e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -612,6 +612,7 @@ static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); } } + 
request.addParameters(params.asMap()); return request; } @@ -632,6 +633,7 @@ static Request getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest getS params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); } } + request.addParameters(params.asMap()); return request; } @@ -646,6 +648,7 @@ static Request startDataFrameAnalytics(StartDataFrameAnalyticsRequest startReque if (startRequest.getTimeout() != null) { params.withTimeout(startRequest.getTimeout()); } + request.addParameters(params.asMap()); return request; } From ca29685585e0ea8661243a7f3505576ffe5660be Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 31 May 2019 12:57:54 +0300 Subject: [PATCH 53/67] [FEATURE][ML] Data frame analytics stop API (#42624) Adds a stop API for data frame analytics. Co-Authored-By: Benjamin Trent --- .../elasticsearch/xpack/core/ml/MlTasks.java | 4 + .../action/StopDataFrameAnalyticsAction.java | 225 ++++++++++++++++ .../xpack/core/ml/MlTasksTests.java | 7 + .../StartDataFrameAnalyticsRequestTests.java | 2 +- ...DataFrameAnalyticsActionResponseTests.java | 23 ++ .../StopDataFrameAnalyticsRequestTests.java | 40 +++ .../ml/qa/ml-with-security/build.gradle | 4 + ...NativeDataFrameAnalyticsIntegTestCase.java | 6 + .../integration/RunDataFrameAnalyticsIT.java | 54 +++- .../xpack/ml/MachineLearning.java | 10 +- ...ransportStartDataFrameAnalyticsAction.java | 65 ++++- ...TransportStopDataFrameAnalyticsAction.java | 247 ++++++++++++++++++ .../dataframe/DataFrameAnalyticsManager.java | 22 +- .../extractor/DataFrameDataExtractor.java | 1 + .../DataFrameAnalyticsConfigProvider.java | 7 +- .../process/AnalyticsProcessManager.java | 117 +++++++-- .../process/AnalyticsResultProcessor.java | 26 +- .../ml/process/AbstractNativeProcess.java | 3 +- .../xpack/ml/process/MlMemoryTracker.java | 2 +- .../RestStopDataFrameAnalyticsAction.java | 54 ++++ .../AnalyticsResultProcessorTests.java | 15 +- .../ml/process/MlMemoryTrackerTests.java | 2 +- 
.../api/ml.stop_data_frame_analytics.json | 31 +++ .../test/ml/stop_data_frame_analytics.yml | 66 +++++ 24 files changed, 983 insertions(+), 50 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStopDataFrameAnalyticsAction.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/test/ml/stop_data_frame_analytics.yml diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index 8064abebc296b..9ac63f026b089 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -61,6 +61,10 @@ public static String dataFrameAnalyticsTaskId(String id) { return DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + id; } + public static String dataFrameAnalyticsIdFromTaskId(String taskId) { + return taskId.replaceFirst(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX, ""); + } + @Nullable public static PersistentTasksCustomMetaData.PersistentTask getJobTask(String jobId, @Nullable PersistentTasksCustomMetaData tasks) { return tasks == null ? 
null : tasks.getTask(jobTaskId(jobId)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..69aa60501fdac --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java @@ -0,0 +1,225 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import 
java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +public class StopDataFrameAnalyticsAction extends Action { + + public static final StopDataFrameAnalyticsAction INSTANCE = new StopDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/stop"; + + private StopDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public Writeable.Reader getResponseReader() { + return Response::new; + } + + public static class Request extends BaseTasksRequest implements ToXContentObject { + + public static final ParseField TIMEOUT = new ParseField("timeout"); + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + + static { + PARSER.declareString((request, id) -> request.id = id, DataFrameAnalyticsConfig.ID); + PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); + } + + public static Request parseRequest(String id, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (request.getId() == null) { + request.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(request.getId())) { + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + request.getId(), id)); + } + return request; + } + + private String id; + private Set expandedIds = Collections.emptySet(); + private boolean allowNoMatch = true; + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + expandedIds = new HashSet<>(Arrays.asList(in.readStringArray())); + allowNoMatch = in.readBoolean(); + } + + public Request() {} + + public 
final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public String getId() { + return id; + } + + @Nullable + public Set getExpandedIds() { + return expandedIds; + } + + public void setExpandedIds(Set expandedIds) { + this.expandedIds = Objects.requireNonNull(expandedIds); + } + + public boolean allowNoMatch() { + return allowNoMatch; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + out.writeStringArray(expandedIds.toArray(new String[0])); + out.writeBoolean(allowNoMatch); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (id != null) { + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + } + builder.field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(id, getTimeout(), expandedIds, allowNoMatch); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + StopDataFrameAnalyticsAction.Request other = (StopDataFrameAnalyticsAction.Request) obj; + return Objects.equals(id, other.id) + && Objects.equals(getTimeout(), other.getTimeout()) + && Objects.equals(expandedIds, other.expandedIds) + && allowNoMatch == other.allowNoMatch; + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { + + private final boolean stopped; + + public Response(boolean stopped) { + super(null, null); + this.stopped = 
stopped; + } + + public Response(StreamInput in) throws IOException { + super(in); + stopped = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(stopped); + } + + public boolean isStopped() { + return stopped; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentCommon(builder, params); + builder.field("stopped", stopped); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Response response = (Response) o; + return stopped == response.stopped; + } + + @Override + public int hashCode() { + return Objects.hash(stopped); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client, StopDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java index 3afe76b8b171f..f2015b1a2bbb5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java @@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; public class MlTasksTests extends ESTestCase { public void testGetJobState() { @@ -161,4 +162,10 @@ public void testUnallocatedDatafeedIds() { assertThat(MlTasks.unallocatedDatafeedIds(tasksBuilder.build(), nodes), containsInAnyOrder("datafeed_without_assignment", "datafeed_without_node")); } + + public void testDataFrameAnalyticsTaskIds() { + String taskId = 
MlTasks.dataFrameAnalyticsTaskId("foo"); + assertThat(taskId, equalTo("data_frame_analytics-foo")); + assertThat(MlTasks.dataFrameAnalyticsIdFromTaskId(taskId), equalTo("foo")); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java index a7025976134d7..a3db5833b820d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java @@ -18,7 +18,7 @@ protected Request createTestInstance() { if (randomBoolean()) { request.setTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); } - return new Request(randomAlphaOfLength(20)); + return request; } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java new file mode 100644 index 0000000000000..d06d8cb1a1860 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction.Response; + +public class StopDataFrameAnalyticsActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + return new Response(randomBoolean()); + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..9c61164c5f02a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction.Request; + +import java.util.HashSet; +import java.util.Set; + +public class StopDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + if (randomBoolean()) { + request.setTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + if (randomBoolean()) { + request.setAllowNoMatch(randomBoolean()); + } + int expandedIdsCount = randomIntBetween(0, 10); + Set expandedIds = new HashSet<>(); + for (int i = 0; i < expandedIdsCount; i++) { + expandedIds.add(randomAlphaOfLength(20)); + } + request.setExpandedIds(expandedIds); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index 54e9c4e123593..d73f5a71de3e2 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -123,6 +123,10 @@ integTestRunner { 'ml/start_stop_datafeed/Test start datafeed job, but not open', 'ml/start_stop_datafeed/Test start non existing datafeed', 'ml/start_stop_datafeed/Test stop non existing datafeed', + 'ml/stop_data_frame_analytics/Test stop given missing config and allow_no_match is true', + 'ml/stop_data_frame_analytics/Test stop given missing config and allow_no_match is false', + 'ml/stop_data_frame_analytics/Test stop with expression that does not match and allow_no_match is false', + 'ml/stop_data_frame_analytics/Test stop with inconsistent body/param ids', 'ml/update_model_snapshot/Test without description', 
'ml/validate/Test invalid job config', 'ml/validate/Test job config is invalid because model snapshot id set', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index c82f5760c637b..87e723db04896 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; @@ -70,6 +71,11 @@ protected AcknowledgedResponse startAnalytics(String id) { return client().execute(StartDataFrameAnalyticsAction.INSTANCE, request).actionGet(); } + protected StopDataFrameAnalyticsAction.Response stopAnalytics(String id) { + StopDataFrameAnalyticsAction.Request request = new StopDataFrameAnalyticsAction.Request(id); + return client().execute(StopDataFrameAnalyticsAction.INSTANCE, request).actionGet(); + } + protected void waitUntilAnalyticsIsStopped(String id) throws Exception { waitUntilAnalyticsIsStopped(id, TimeValue.timeValueSeconds(30)); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java 
b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 6c76fd0cd7af6..1f3899939938e 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.ml.integration; +import org.elasticsearch.action.admin.indices.exists.indices.IndicesExistsRequest; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetResponse; @@ -30,9 +31,8 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; public class RunDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase { @@ -147,6 +147,56 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); } + public void testStopOutlierDetectionWithEnoughDocumentsToScroll() { + String sourceIndex = "test-outlier-detection-with-enough-docs-to-scroll"; + + client().admin().indices().prepareCreate(sourceIndex) + .addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical_1", "type=keyword") + .get(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + int docCount = randomIntBetween(1024, 2048); + for (int i = 0; i < docCount; i++) { + IndexRequest indexRequest = new IndexRequest(sourceIndex); + 
indexRequest.source("numeric_1", randomDouble(), "numeric_2", randomFloat(), "categorical_1", randomAlphaOfLength(10)); + bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_outlier_detection_with_enough_docs_to_scroll"; + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex, "custom_ml"); + registerAnalytics(config); + putAnalytics(config); + + assertState(id, DataFrameAnalyticsState.STOPPED); + startAnalytics(id); + assertState(id, DataFrameAnalyticsState.STARTED); + + assertThat(stopAnalytics(id).isStopped(), is(true)); + assertState(id, DataFrameAnalyticsState.STOPPED); + + if (client().admin().indices().exists(new IndicesExistsRequest(config.getDest().getIndex())).actionGet().isExists() == false) { + // We stopped before we even created the destination index + return; + } + + SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get(); + if (searchResponse.getHits().getTotalHits().value == docCount) { + searchResponse = client().prepareSearch(config.getDest().getIndex()) + .setTrackTotalHits(true) + .setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get(); + logger.debug("We stopped during analysis: [{}] < [{}]", searchResponse.getHits().getTotalHits().value, docCount); + assertThat(searchResponse.getHits().getTotalHits().value, lessThan((long) docCount)); + } else { + logger.debug("We stopped during reindexing: [{}] < [{}]", searchResponse.getHits().getTotalHits().value, docCount); + } + } + private static DataFrameAnalyticsConfig buildOutlierDetectionAnalytics(String id, String sourceIndex, @Nullable String resultsField) { DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder(id); configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null)); diff 
--git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 82b45ede09fbc..da59762ec89eb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -108,6 +108,7 @@ import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateDatafeedAction; @@ -171,6 +172,7 @@ import org.elasticsearch.xpack.ml.action.TransportSetUpgradeModeAction; import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStartDatafeedAction; +import org.elasticsearch.xpack.ml.action.TransportStopDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.action.TransportStopDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportUpdateCalendarJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateDatafeedAction; @@ -237,6 +239,7 @@ import org.elasticsearch.xpack.ml.rest.dataframe.RestGetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.ml.rest.dataframe.RestPutDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.rest.dataframe.RestStartDataFrameAnalyticsAction; +import org.elasticsearch.xpack.ml.rest.dataframe.RestStopDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.rest.filter.RestDeleteFilterAction; import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction; import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction; 
@@ -501,8 +504,7 @@ public Collection createComponents(Client client, ClusterService cluster this.datafeedManager.set(datafeedManager); // Data frame analytics components - AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, environment, threadPool, - analyticsProcessFactory); + AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory); DataFrameAnalyticsConfigProvider dataFrameAnalyticsConfigProvider = new DataFrameAnalyticsConfigProvider(client); assert client instanceof NodeClient; DataFrameAnalyticsManager dataFrameAnalyticsManager = new DataFrameAnalyticsManager(clusterService, (NodeClient) client, @@ -556,7 +558,7 @@ public List> getPersistentTasksExecutor(ClusterServic new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get(), memoryTracker.get(), client), new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(datafeedManager.get()), - new TransportStartDataFrameAnalyticsAction.TaskExecutor(settings, clusterService, dataFrameAnalyticsManager.get(), + new TransportStartDataFrameAnalyticsAction.TaskExecutor(settings, client, clusterService, dataFrameAnalyticsManager.get(), memoryTracker.get()) ); } @@ -629,6 +631,7 @@ public List getRestHandlers(Settings settings, RestController restC new RestPutDataFrameAnalyticsAction(settings, restController), new RestDeleteDataFrameAnalyticsAction(settings, restController), new RestStartDataFrameAnalyticsAction(settings, restController), + new RestStopDataFrameAnalyticsAction(settings, restController), new RestEvaluateDataFrameAction(settings, restController) ); } @@ -694,6 +697,7 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(PutDataFrameAnalyticsAction.INSTANCE, TransportPutDataFrameAnalyticsAction.class), new ActionHandler<>(DeleteDataFrameAnalyticsAction.INSTANCE, 
TransportDeleteDataFrameAnalyticsAction.class), new ActionHandler<>(StartDataFrameAnalyticsAction.INSTANCE, TransportStartDataFrameAnalyticsAction.class), + new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class), new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java index fb10074c25763..3d195f09b26d5 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsAction.java @@ -10,7 +10,10 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.TransportMasterNodeAction; @@ -44,6 +47,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; import 
org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; @@ -250,13 +254,21 @@ public void onFailure(Exception e) { public static class DataFrameAnalyticsTask extends AllocatedPersistentTask implements StartDataFrameAnalyticsAction.TaskMatcher { + private final Client client; + private final ClusterService clusterService; + private final DataFrameAnalyticsManager analyticsManager; private final StartDataFrameAnalyticsAction.TaskParams taskParams; @Nullable private volatile Long reindexingTaskId; + private volatile boolean isStopping; public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map headers, + Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager, StartDataFrameAnalyticsAction.TaskParams taskParams) { super(id, type, action, MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + taskParams.getId(), parentTask, headers); + this.client = Objects.requireNonNull(client); + this.clusterService = Objects.requireNonNull(clusterService); + this.analyticsManager = Objects.requireNonNull(analyticsManager); this.taskParams = Objects.requireNonNull(taskParams); } @@ -264,7 +276,7 @@ public StartDataFrameAnalyticsAction.TaskParams getParams() { return taskParams; } - public void setReindexingTaskId(long reindexingTaskId) { + public void setReindexingTaskId(Long reindexingTaskId) { this.reindexingTaskId = reindexingTaskId; } @@ -272,10 +284,54 @@ public void setReindexingTaskId(long reindexingTaskId) { public Long getReindexingTaskId() { return reindexingTaskId; } + + public boolean isStopping() { + return isStopping; + } + + @Override + protected void onCancelled() { + stop(getReasonCancelled(), TimeValue.ZERO); + } + + public void stop(String reason, TimeValue timeout) { + isStopping = true; + if (reindexingTaskId != null) { + cancelReindexingTask(reason, timeout); + } + analyticsManager.stop(this); + } + + private void cancelReindexingTask(String reason, TimeValue timeout) { + TaskId 
reindexTaskId = new TaskId(clusterService.localNode().getId(), reindexingTaskId); + LOGGER.debug("[{}] Cancelling reindex task [{}]", taskParams.getId(), reindexTaskId); + + CancelTasksRequest cancelReindex = new CancelTasksRequest(); + cancelReindex.setTaskId(reindexTaskId); + cancelReindex.setReason(reason); + cancelReindex.setTimeout(timeout); + CancelTasksResponse cancelReindexResponse = client.admin().cluster().cancelTasks(cancelReindex).actionGet(); + Throwable firstError = null; + if (cancelReindexResponse.getNodeFailures().isEmpty() == false) { + firstError = cancelReindexResponse.getNodeFailures().get(0).getRootCause(); + } + if (cancelReindexResponse.getTaskFailures().isEmpty() == false) { + firstError = cancelReindexResponse.getTaskFailures().get(0).getCause(); + } + // There is a chance that the task is finished by the time we cancel it in which case we'll get + // a ResourceNotFoundException which we can ignore. + if (firstError != null && firstError instanceof ResourceNotFoundException == false) { + throw ExceptionsHelper.serverError("[" + taskParams.getId() + "] Error cancelling reindex task", firstError); + } else { + LOGGER.debug("[{}] Reindex task was successfully cancelled", taskParams.getId()); + } + } } public static class TaskExecutor extends PersistentTasksExecutor { + private final Client client; + private final ClusterService clusterService; private final DataFrameAnalyticsManager manager; private final MlMemoryTracker memoryTracker; @@ -283,9 +339,11 @@ public static class TaskExecutor extends PersistentTasksExecutor persistentTask, Map headers) { - return new DataFrameAnalyticsTask(id, type, action, parentTaskId, headers, persistentTask.getParams()); + return new DataFrameAnalyticsTask(id, type, action, parentTaskId, headers, client, clusterService, manager, + persistentTask.getParams()); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..7c8222d83f3e3 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStopDataFrameAnalyticsAction.java @@ -0,0 +1,247 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.FailedNodeException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.tasks.TransportTasksAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.discovery.MasterNotDiscoveredException; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Stops the persistent task for running data frame analytics. + * + * TODO Add to the upgrade mode action + */ +public class TransportStopDataFrameAnalyticsAction + extends TransportTasksAction { + + private static final Logger logger = LogManager.getLogger(TransportStopDataFrameAnalyticsAction.class); + + private final ThreadPool threadPool; + private final PersistentTasksService persistentTasksService; + private final DataFrameAnalyticsConfigProvider configProvider; + + @Inject + public TransportStopDataFrameAnalyticsAction(TransportService transportService, ActionFilters actionFilters, + ClusterService clusterService, ThreadPool threadPool, + PersistentTasksService persistentTasksService, + DataFrameAnalyticsConfigProvider configProvider) { + super(StopDataFrameAnalyticsAction.NAME, clusterService, transportService, actionFilters, StopDataFrameAnalyticsAction.Request::new, + StopDataFrameAnalyticsAction.Response::new, StopDataFrameAnalyticsAction.Response::new, ThreadPool.Names.SAME); + this.threadPool = threadPool; + this.persistentTasksService = persistentTasksService; + this.configProvider = configProvider; + } + + @Override + protected void doExecute(Task task, StopDataFrameAnalyticsAction.Request request, + ActionListener listener) { + ClusterState state = clusterService.state(); + DiscoveryNodes nodes = state.nodes(); + if (nodes.isLocalNodeElectedMaster() == false) { + redirectToMasterNode(nodes.getMasterNode(), request, listener); + return; + } + + 
logger.debug("Received request to stop data frame analytics [{}]", request.getId()); + + ActionListener> expandedIdsListener = ActionListener.wrap( + expandedIds -> { + logger.debug("Resolved data frame analytics to stop: {}", expandedIds); + if (expandedIds.isEmpty()) { + listener.onResponse(new StopDataFrameAnalyticsAction.Response(true)); + return; + } + + Set startedAnalytics = new HashSet<>(); + Set stoppingAnalytics = new HashSet<>(); + PersistentTasksCustomMetaData tasks = state.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + sortAnalyticsByTaskState(expandedIds, tasks, startedAnalytics, stoppingAnalytics); + + request.setExpandedIds(startedAnalytics); + request.setNodes(findAllocatedNodesAndRemoveUnassignedTasks(startedAnalytics, tasks)); + + ActionListener finalListener = ActionListener.wrap( + r -> waitForTaskRemoved(expandedIds, request, r, listener), + listener::onFailure + ); + + super.doExecute(task, request, finalListener); + }, + listener::onFailure + ); + + expandIds(state, request, expandedIdsListener); + } + + private static void sortAnalyticsByTaskState(Set analyticsIds, PersistentTasksCustomMetaData tasks, + Set startedAnalytics, Set stoppingAnalytics) { + for (String analyticsId : analyticsIds) { + switch (MlTasks.getDataFrameAnalyticsState(analyticsId, tasks)) { + case STARTED: + case REINDEXING: + case ANALYZING: + startedAnalytics.add(analyticsId); + break; + case STOPPING: + stoppingAnalytics.add(analyticsId); + break; + case STOPPED: + break; + default: + break; + } + } + } + + private void expandIds(ClusterState clusterState, StopDataFrameAnalyticsAction.Request request, + ActionListener> expandedIdsListener) { + ActionListener> configsListener = ActionListener.wrap( + configs -> { + Set matchingIds = configs.stream().map(DataFrameAnalyticsConfig::getId).collect(Collectors.toSet()); + PersistentTasksCustomMetaData tasksMetaData = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + Set startedIds = 
tasksMetaData == null ? Collections.emptySet() : tasksMetaData.tasks().stream() + .filter(t -> t.getId().startsWith(MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX)) + .map(t -> t.getId().replaceFirst(MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX, "")) + .collect(Collectors.toSet()); + startedIds.retainAll(matchingIds); + expandedIdsListener.onResponse(startedIds); + }, + expandedIdsListener::onFailure + ); + + configProvider.getMultiple(request.getId(), request.allowNoMatch(), configsListener); + } + + private String[] findAllocatedNodesAndRemoveUnassignedTasks(Set analyticsIds, PersistentTasksCustomMetaData tasks) { + List nodes = new ArrayList<>(); + for (String analyticsId : analyticsIds) { + PersistentTasksCustomMetaData.PersistentTask task = MlTasks.getDataFrameAnalyticsTask(analyticsId, tasks); + if (task == null) { + // This should not be possible; we filtered started analytics thus the task should exist + String msg = "Requested data frame analytics [" + analyticsId + "] be stopped but the task could not be found"; + assert task != null : msg; + } else if (task.isAssigned()) { + nodes.add(task.getExecutorNode()); + } else { + // This means the task has not been assigned to a node yet so + // we can stop it by removing its persistent task. + // The listener is a no-op as we're already going to wait for the task to be removed. 
+ persistentTasksService.sendRemoveRequest(task.getId(), ActionListener.wrap(r -> {}, e -> {})); + } + } + return nodes.toArray(new String[0]); + } + + private void redirectToMasterNode(DiscoveryNode masterNode, StopDataFrameAnalyticsAction.Request request, + ActionListener listener) { + if (masterNode == null) { + listener.onFailure(new MasterNotDiscoveredException("no known master node")); + } else { + transportService.sendRequest(masterNode, actionName, request, + new ActionListenerResponseHandler<>(listener, StopDataFrameAnalyticsAction.Response::new)); + } + } + + @Override + protected StopDataFrameAnalyticsAction.Response newResponse(StopDataFrameAnalyticsAction.Request request, + List tasks, + List taskOperationFailures, + List failedNodeExceptions) { + if (request.getExpandedIds().size() != tasks.size()) { + if (taskOperationFailures.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(taskOperationFailures.get(0).getCause()); + } else if (failedNodeExceptions.isEmpty() == false) { + throw org.elasticsearch.ExceptionsHelper.convertToElastic(failedNodeExceptions.get(0)); + } else { + // This can happen when the actual task in the node no longer exists, + // which means the data frame analytic(s) have already been closed. 
+ return new StopDataFrameAnalyticsAction.Response(true); + } + } + return new StopDataFrameAnalyticsAction.Response(tasks.stream().allMatch(StopDataFrameAnalyticsAction.Response::isStopped)); + } + + @Override + protected void taskOperation(StopDataFrameAnalyticsAction.Request request, + TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask task, + ActionListener listener) { + DataFrameAnalyticsTaskState stoppingState = + new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.STOPPING, task.getAllocationId()); + task.updatePersistentTaskState(stoppingState, ActionListener.wrap(pTask -> { + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + + @Override + protected void doRun() { + task.stop("stop_data_frame_analytics (api)", request.getTimeout()); + listener.onResponse(new StopDataFrameAnalyticsAction.Response(true)); + } + }); + }, + e -> { + if (e instanceof ResourceNotFoundException) { + // the task has disappeared so must have stopped + listener.onResponse(new StopDataFrameAnalyticsAction.Response(true)); + } else { + listener.onFailure(e); + } + })); + } + + void waitForTaskRemoved(Set analyticsIds, StopDataFrameAnalyticsAction.Request request, + StopDataFrameAnalyticsAction.Response response, + ActionListener listener) { + persistentTasksService.waitForPersistentTasksCondition(persistentTasks -> + filterPersistentTasks(persistentTasks, analyticsIds).isEmpty(), + request.getTimeout(), ActionListener.wrap( + booleanResponse -> listener.onResponse(response), + listener::onFailure + )); + } + + private static Collection> filterPersistentTasks( + PersistentTasksCustomMetaData persistentTasks, Set analyticsIds) { + return persistentTasks.findTasks(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + t -> analyticsIds.contains(MlTasks.dataFrameAnalyticsIdFromTaskId(t.getId()))); + } +} diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index 02369adfd785d..c0f394b6dc890 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -137,9 +137,18 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current } private void reindexDataframeAndStartAnalysis(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config) { + if (task.isStopping()) { + // The task was requested to stop before we started reindexing + task.markAsCompleted(); + return; + } + // Reindexing is complete; start analytics ActionListener refreshListener = ActionListener.wrap( - refreshResponse -> startAnalytics(task, config, false), + refreshResponse -> { + task.setReindexingTaskId(null); + startAnalytics(task, config, false); + }, task::markAsFailed ); @@ -184,7 +193,7 @@ private void startAnalytics(DataFrameAnalyticsTask task, DataFrameAnalyticsConfi DataFrameAnalyticsTaskState analyzingState = new DataFrameAnalyticsTaskState(DataFrameAnalyticsState.ANALYZING, task.getAllocationId()); task.updatePersistentTaskState(analyzingState, ActionListener.wrap( - updatedTask -> processManager.runJob(task.getAllocationId(), config, dataExtractorFactory, + updatedTask -> processManager.runJob(task, config, dataExtractorFactory, error -> { if (error != null) { task.markAsFailed(error); @@ -213,7 +222,7 @@ private void createDestinationIndex(String sourceIndex, String destinationIndex, } Settings.Builder settingsBuilder = Settings.builder().put(indexMetaData.getSettings()); - INTERNAL_SETTINGS.stream().forEach(settingsBuilder::remove); + INTERNAL_SETTINGS.forEach(settingsBuilder::remove); settingsBuilder.put(IndexSortConfig.INDEX_SORT_FIELD_SETTING.getKey(), 
DataFrameAnalyticsFields.ID); settingsBuilder.put(IndexSortConfig.INDEX_SORT_ORDER_SETTING.getKey(), SortOrder.ASC); @@ -230,11 +239,18 @@ private void createDestinationIndex(String sourceIndex, String destinationIndex, private static void addDestinationIndexMappings(IndexMetaData indexMetaData, CreateIndexRequest createIndexRequest) { ImmutableOpenMap mappings = indexMetaData.getMappings(); Map mappingsAsMap = mappings.valuesIt().next().sourceAsMap(); + + @SuppressWarnings("unchecked") Map properties = (Map) mappingsAsMap.get("properties"); + Map idCopyMapping = new HashMap<>(); idCopyMapping.put("type", "keyword"); properties.put(DataFrameAnalyticsFields.ID, idCopyMapping); createIndexRequest.mapping(mappings.keysIt().next(), mappingsAsMap); } + + public void stop(DataFrameAnalyticsTask task) { + processManager.stop(task); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 7f01f800d1d71..a45185ebe213f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -74,6 +74,7 @@ public boolean isCancelled() { } public void cancel() { + LOGGER.debug("[{}] Data extractor was cancelled", context.jobId); isCancelled = true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java index 4aefd57fb0eae..569469452cf64 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/persistence/DataFrameAnalyticsConfigProvider.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; @@ -37,6 +38,8 @@ public class DataFrameAnalyticsConfigProvider { + private static final int MAX_CONFIGS_SIZE = 10000; + private static final Map TO_XCONTENT_PARAMS; static { @@ -108,9 +111,11 @@ public void get(String id, ActionListener listener) { /** * @param ids a comma separated list of single IDs and/or wildcards */ - public void getMultiple(String ids, ActionListener> listener) { + public void getMultiple(String ids, boolean allowNoMatch, ActionListener> listener) { GetDataFrameAnalyticsAction.Request request = new GetDataFrameAnalyticsAction.Request(); + request.setPageParams(new PageParams(0, MAX_CONFIGS_SIZE)); request.setResourceId(ids); + request.setAllowNoResources(allowNoMatch); executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, request, ActionListener.wrap( response -> listener.onResponse(response.getResources().results()), listener::onFailure)); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 52b308a3aa98e..c1447f4d18b42 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java 
@@ -12,12 +12,12 @@ import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; -import org.elasticsearch.env.Environment; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.action.TransportStartDataFrameAnalyticsAction; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; @@ -40,27 +40,35 @@ public class AnalyticsProcessManager { private final AnalyticsProcessFactory processFactory; private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); - public AnalyticsProcessManager(Client client, Environment environment, ThreadPool threadPool, - AnalyticsProcessFactory analyticsProcessFactory) { + public AnalyticsProcessManager(Client client, ThreadPool threadPool, AnalyticsProcessFactory analyticsProcessFactory) { this.client = Objects.requireNonNull(client); this.threadPool = Objects.requireNonNull(threadPool); this.processFactory = Objects.requireNonNull(analyticsProcessFactory); } - public void runJob(long taskAllocationId, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory, - Consumer finishHandler) { + public void runJob(TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, + DataFrameDataExtractorFactory dataExtractorFactory, Consumer finishHandler) { threadPool.generic().execute(() -> { - DataFrameDataExtractor dataExtractor = dataExtractorFactory.newExtractor(false); - AnalyticsProcess process = createProcess(config.getId(), createProcessConfig(config, dataExtractor)); - 
processContextByAllocation.putIfAbsent(taskAllocationId, new ProcessContext()); - ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); - DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, - dataExtractorFactory.newExtractor(true)); - AnalyticsResultProcessor resultProcessor = new AnalyticsResultProcessor(processContextByAllocation.get(taskAllocationId), - dataFrameRowsJoiner); - executorService.execute(() -> resultProcessor.process(process)); - executorService.execute( - () -> processData(taskAllocationId, config, dataExtractor, process, resultProcessor, finishHandler)); + if (task.isStopping()) { + // The task was requested to stop before we created the process context + finishHandler.accept(null); + return; + } + + ProcessContext processContext = new ProcessContext(config.getId()); + if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { + finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id + + "] Could not create process as one already exists")); + return; + } + if (processContext.startProcess(dataExtractorFactory, config)) { + ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); + executorService.execute(() -> processContext.resultProcessor.process(processContext.process)); + executorService.execute(() -> processData(task.getAllocationId(), config, processContext.dataExtractor, + processContext.process, processContext.resultProcessor, finishHandler)); + } else { + finishHandler.accept(null); + } }); } @@ -92,6 +100,8 @@ private void processData(long taskAllocationId, DataFrameAnalyticsConfig config, finishHandler.accept(e); } processContextByAllocation.remove(taskAllocationId); + LOGGER.debug("Removed process context for task [{}]; [{}] processes still running", config.getId(), + processContextByAllocation.size()); } } @@ -141,13 +151,6 @@ private AnalyticsProcess 
createProcess(String jobId, AnalyticsProcessConfig anal return process; } - private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { - DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); - AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, - config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis()); - return processConfig; - } - @Nullable public Integer getProgressPercent(long allocationId) { ProcessContext processContext = processContextByAllocation.get(allocationId); @@ -159,12 +162,78 @@ private void refreshDest(DataFrameAnalyticsConfig config) { () -> client.execute(RefreshAction.INSTANCE, new RefreshRequest(config.getDest().getIndex())).actionGet()); } - static class ProcessContext { + public void stop(TransportStartDataFrameAnalyticsAction.DataFrameAnalyticsTask task) { + ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); + if (processContext != null) { + LOGGER.debug("[{}] Stopping process", task.getParams().getId() ); + processContext.stop(); + } else { + LOGGER.debug("[{}] No process context to stop", task.getParams().getId() ); + } + } + + class ProcessContext { + private final String id; + private volatile AnalyticsProcess process; + private volatile DataFrameDataExtractor dataExtractor; + private volatile AnalyticsResultProcessor resultProcessor; private final AtomicInteger progressPercent = new AtomicInteger(0); + private volatile boolean processKilled; + + ProcessContext(String id) { + this.id = Objects.requireNonNull(id); + } + + public String getId() { + return id; + } + + public boolean isProcessKilled() { + return processKilled; + } void setProgressPercent(int progressPercent) { this.progressPercent.set(progressPercent); } + + public synchronized void stop() { + LOGGER.debug("[{}] Stopping process", id); + processKilled = true; + 
if (dataExtractor != null) { + dataExtractor.cancel(); + } + if (process != null) { + try { + process.kill(); + } catch (IOException e) { + LOGGER.error(new ParameterizedMessage("[{}] Failed to kill process", id), e); + } + } + } + + /** + * @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime + */ + private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config) { + if (processKilled) { + // The job was stopped before we started the process so no need to start it + return false; + } + + dataExtractor = dataExtractorFactory.newExtractor(false); + process = createProcess(config.getId(), createProcessConfig(config, dataExtractor)); + DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, + dataExtractorFactory.newExtractor(true)); + resultProcessor = new AnalyticsResultProcessor(id, dataFrameRowsJoiner, this::isProcessKilled, this::setProgressPercent); + return true; + } + + private AnalyticsProcessConfig createProcessConfig(DataFrameAnalyticsConfig config, DataFrameDataExtractor dataExtractor) { + DataFrameDataExtractor.DataSummary dataSummary = dataExtractor.collectDataSummary(); + AnalyticsProcessConfig processConfig = new AnalyticsProcessConfig(dataSummary.rows, dataSummary.cols, + config.getModelMemoryLimit(), 1, config.getDest().getResultsField(), config.getAnalysis()); + return processConfig; + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 7e204e52e0c4a..f9b131393541a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -7,34 +7,42 
@@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import java.util.Iterator; import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.function.Supplier; public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); - private final AnalyticsProcessManager.ProcessContext processContext; + private final String dataFrameAnalyticsId; private final DataFrameRowsJoiner dataFrameRowsJoiner; + private final Supplier isProcessKilled; + private final Consumer progressConsumer; private final CountDownLatch completionLatch = new CountDownLatch(1); - public AnalyticsResultProcessor(AnalyticsProcessManager.ProcessContext processContext, DataFrameRowsJoiner dataFrameRowsJoiner) { - this.processContext = Objects.requireNonNull(processContext); + public AnalyticsResultProcessor(String dataFrameAnalyticsId, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier isProcessKilled, + Consumer progressConsumer) { + this.dataFrameAnalyticsId = Objects.requireNonNull(dataFrameAnalyticsId); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); + this.isProcessKilled = Objects.requireNonNull(isProcessKilled); + this.progressConsumer = Objects.requireNonNull(progressConsumer); } public void awaitForCompletion() { try { if (completionLatch.await(30, TimeUnit.MINUTES) == false) { - LOGGER.warn("Timeout waiting for results processor to complete"); + LOGGER.warn("[{}] Timeout waiting for results processor to complete", dataFrameAnalyticsId); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); - LOGGER.info("Interrupted waiting for results processor to complete"); + LOGGER.info("[{}] Interrupted waiting for results processor 
to complete", dataFrameAnalyticsId); } } @@ -47,7 +55,11 @@ public void process(AnalyticsProcess process) { processResult(result, resultsJoiner); } } catch (Exception e) { - LOGGER.error("Error parsing data frame analytics output", e); + if (isProcessKilled.get()) { + // No need to log error as it's due to stopping + } else { + LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", dataFrameAnalyticsId), e); + } } finally { completionLatch.countDown(); process.consumeAndCloseOutputStream(); @@ -61,7 +73,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } Integer progressPercent = result.getProgressPercent(); if (progressPercent != null) { - processContext.setProgressPercent(progressPercent); + progressConsumer.accept(progressPercent); } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java index c270de69e79ea..60673467ba0e4 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/AbstractNativeProcess.java @@ -177,6 +177,7 @@ public void close() throws IOException { @Override public void kill() throws IOException { + LOGGER.debug("[{}] Killing {} process", jobId, getName()); processKilled = true; try { // The PID comes via the processes log stream. 
We don't wait for it to arrive here, @@ -274,7 +275,7 @@ public void consumeAndCloseOutputStream() { } processOutStream().close(); } catch (IOException e) { - throw new RuntimeException("Error closing result parser input stream", e); + // Given we are closing down the process there is no point propagating IO exceptions here } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java index 09997406955da..afd670a180384 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java @@ -344,7 +344,7 @@ private void refreshAllDataFrameAnalyticsJobTasks(List ((StartDataFrameAnalyticsAction.TaskParams) task.getParams()).getId()).sorted().collect(Collectors.joining(",")); - configProvider.getMultiple(startedJobIds, ActionListener.wrap( + configProvider.getMultiple(startedJobIds, false, ActionListener.wrap( analyticsConfigs -> { for (DataFrameAnalyticsConfig analyticsConfig : analyticsConfigs) { memoryRequirementByDataFrameAnalyticsJob.put(analyticsConfig.getId(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStopDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStopDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..8a399c736c92e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestStopDataFrameAnalyticsAction.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.rest.dataframe; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +import java.io.IOException; + +public class RestStopDataFrameAnalyticsAction extends BaseRestHandler { + + public RestStopDataFrameAnalyticsAction(Settings settings, RestController controller) { + super(settings); + controller.registerHandler(RestRequest.Method.POST, MachineLearning.BASE_PATH + "data_frame/analytics/{" + + DataFrameAnalyticsConfig.ID.getPreferredName() + "}/_stop", this); + } + + @Override + public String getName() { + return "xpack_ml_stop_data_frame_analytics_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String id = restRequest.param(DataFrameAnalyticsConfig.ID.getPreferredName()); + StopDataFrameAnalyticsAction.Request request; + if (restRequest.hasContentOrSourceParam()) { + request = StopDataFrameAnalyticsAction.Request.parseRequest(id, restRequest.contentOrSourceParamParser()); + } else { + request = new StopDataFrameAnalyticsAction.Request(id); + if (restRequest.hasParam(StopDataFrameAnalyticsAction.Request.TIMEOUT.getPreferredName())) { + TimeValue timeout = restRequest.paramAsTime(StopDataFrameAnalyticsAction.Request.TIMEOUT.getPreferredName(), + request.getTimeout()); + request.setTimeout(timeout); + } + if (restRequest.hasParam(StopDataFrameAnalyticsAction.Request.ALLOW_NO_MATCH.getPreferredName())) { + 
request.setAllowNoMatch(restRequest.paramAsBoolean(StopDataFrameAnalyticsAction.Request.ALLOW_NO_MATCH.getPreferredName(), + request.allowNoMatch())); + } + } + return channel -> client.execute(StopDataFrameAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index e3f4cf6ebc9f7..4032f2d65bf34 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -15,6 +15,7 @@ import java.util.Collections; import java.util.List; +import static org.hamcrest.Matchers.equalTo; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -22,8 +23,12 @@ public class AnalyticsResultProcessorTests extends ESTestCase { + private static final String JOB_ID = "analytics-result-processor-tests"; + private AnalyticsProcess process; private DataFrameRowsJoiner dataFrameRowsJoiner; + private int progressPercent; + @Before public void setUpMocks() { @@ -43,7 +48,7 @@ public void testProcess_GivenNoResults() { } public void testProcess_GivenEmptyResults() { - givenProcessResults(Arrays.asList(new AnalyticsResult(null, null), new AnalyticsResult(null, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(null, 50), new AnalyticsResult(null, 100))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -51,12 +56,13 @@ public void testProcess_GivenEmptyResults() { verify(dataFrameRowsJoiner).close(); Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner); + assertThat(progressPercent, equalTo(100)); } public 
void testProcess_GivenRowResults() { RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null), new AnalyticsResult(rowResults2, null))); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50), new AnalyticsResult(rowResults2, 100))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -65,6 +71,8 @@ public void testProcess_GivenRowResults() { InOrder inOrder = Mockito.inOrder(dataFrameRowsJoiner); inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1); inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2); + + assertThat(progressPercent, equalTo(100)); } private void givenProcessResults(List results) { @@ -72,6 +80,7 @@ private void givenProcessResults(List results) { } private AnalyticsResultProcessor createResultProcessor() { - return new AnalyticsResultProcessor(new AnalyticsProcessManager.ProcessContext(), dataFrameRowsJoiner); + return new AnalyticsResultProcessor(JOB_ID, dataFrameRowsJoiner, () -> false, + progressPercent -> this.progressPercent = progressPercent); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java index 426a9e0f83984..1dea073123ad2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java @@ -118,7 +118,7 @@ public void testRefreshAll() { String jobId = "job" + i; verify(jobResultsProvider, times(1)).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(), any()); } - verify(configProvider, times(1)).getMultiple(eq(String.join(",", allIds)), any(ActionListener.class)); + verify(configProvider, times(1)).getMultiple(eq(String.join(",", 
allIds)), eq(false), any(ActionListener.class)); } else { verify(jobResultsProvider, never()).getEstablishedMemoryUsage(anyString(), any(), any(), any(), any()); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json new file mode 100644 index 0000000000000..cc95def45fce6 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json @@ -0,0 +1,31 @@ +{ + "ml.stop_data_frame_analytics": { + "methods": [ "POST" ], + "url": { + "path": "/_ml/data_frame/analytics/{id}/_stop", + "paths": [ "/_ml/data_frame/analytics/{id}/_stop" ], + "parts": { + "id": { + "type": "string", + "required": true, + "description": "The ID of the data frame analytics to stop" + } + }, + "params": { + "allow_no_match": { + "type": "boolean", + "required": false, + "description": "Whether to ignore if a wildcard expression matches no data frame analytics. (This includes `_all` string or when no data frame analytics have been specified)" + }, + "timeout": { + "type": "time", + "required": false, + "description": "Controls the time to wait until the task has stopped. 
Defaults to 20 seconds" + } + } + }, + "body": { + "description": "The stop data frame analytics parameters" + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/stop_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/stop_data_frame_analytics.yml new file mode 100644 index 0000000000000..673910bce9fdc --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/stop_data_frame_analytics.yml @@ -0,0 +1,66 @@ +setup: + - do: + ml.put_data_frame_analytics: + id: "stop_data_frame_analytics_test_job" + body: > + { + "source": { + "index": "index-source" + }, + "dest": { + "index": "index-dest" + }, + "analysis": {"outlier_detection":{}} + } + +--- +"Test stop given missing config and allow_no_match is true": + - do: + catch: missing + ml.stop_data_frame_analytics: + id: "missing_config" + allow_no_match: true + +--- +"Test stop given missing config and allow_no_match is false": + - do: + catch: missing + ml.stop_data_frame_analytics: + id: "missing_config" + allow_no_match: false + +--- +"Test stop with expression that does not match and allow_no_match is true": + - do: + ml.stop_data_frame_analytics: + id: "missing-*" + allow_no_match: true + - match: { stopped: true } + +--- +"Test stop with expression that does not match and allow_no_match is false": + - do: + catch: missing + ml.stop_data_frame_analytics: + id: "missing-*" + allow_no_match: false + +--- +"Test stop given stopped": + + - do: + ml.stop_data_frame_analytics: + id: "stop_data_frame_analytics_test_job" + - match: { stopped: true } + +--- +"Test stop with inconsistent body/param ids": + + - do: + catch: /Inconsistent id; 'body_id' specified in the body differs from 'url_id' specified as a URL argument/ + ml.stop_data_frame_analytics: + id: "url_id" + body: > + { + "id": "body_id" + } From 626cae5ac03c46a64a47f900b03e8946e1956b02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 31 May 2019 14:09:39 +0200 
Subject: [PATCH 54/67] [ML][FEATURE] Implement client-side evaluation API (#42058) Implement client-side evaluation API --- .../client/MLRequestConverters.java | 10 + .../client/MachineLearningClient.java | 42 +++ .../client/ml/EvaluateDataFrameRequest.java | 136 ++++++++++ .../client/ml/EvaluateDataFrameResponse.java | 117 +++++++++ .../ml/dataframe/evaluation/Evaluation.java | 32 +++ .../evaluation/EvaluationMetric.java | 43 ++++ .../MlEvaluationNamedXContentProvider.java | 57 +++++ .../AbstractConfusionMatrixMetric.java | 47 ++++ .../softclassification/AucRocMetric.java | 241 ++++++++++++++++++ .../BinarySoftClassification.java | 129 ++++++++++ .../ConfusionMatrixMetric.java | 206 +++++++++++++++ .../softclassification/PrecisionMetric.java | 123 +++++++++ .../softclassification/RecallMetric.java | 123 +++++++++ ...icsearch.plugins.spi.NamedXContentProvider | 3 +- .../client/MLRequestConvertersTests.java | 23 ++ .../client/MachineLearningIT.java | 140 +++++++++- .../client/RestHighLevelClientTests.java | 16 +- .../ml/AucRocMetricAucRocPointTests.java | 47 ++++ .../client/ml/AucRocMetricResultTests.java | 63 +++++ ...usionMatrixMetricConfusionMatrixTests.java | 47 ++++ .../ml/ConfusionMatrixMetricResultTests.java | 62 +++++ .../ml/EvaluateDataFrameResponseTests.java | 76 ++++++ .../client/ml/PrecisionMetricResultTests.java | 60 +++++ .../client/ml/RecallMetricResultTests.java | 60 +++++ 24 files changed, 1886 insertions(+), 17 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java create mode 100644 
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index be6bc6f0d403e..11c7c1a4cc40e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -39,6 +39,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FlushJobRequest; import org.elasticsearch.client.ml.ForecastJobRequest; @@ -660,6 +661,15 @@ static Request deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest deleteRe return new Request(HttpDelete.METHOD_NAME, endpoint); } + static Request evaluateDataFrame(EvaluateDataFrameRequest evaluateRequest) throws IOException { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "_evaluate") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + request.setEntity(createEntity(evaluateRequest, REQUEST_BODY_CONTENT_TYPE)); + return request; + } + static Request putFilter(PutFilterRequest putFilterRequest) throws IOException { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 215a674a88aad..002ebd1c45bfc 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -34,6 +34,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import 
org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -2087,4 +2089,44 @@ public void deleteDataFrameAnalyticsAsync(DeleteDataFrameAnalyticsRequest reques listener, Collections.emptySet()); } + + /** + * Evaluates the given Data Frame + *

+ * For additional info + * see Evaluate Data Frame documentation + * + * @param request The {@link EvaluateDataFrameRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link EvaluateDataFrameResponse} response object + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public EvaluateDataFrameResponse evaluateDataFrame(EvaluateDataFrameRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::evaluateDataFrame, + options, + EvaluateDataFrameResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Evaluates the given Data Frame asynchronously and notifies listener upon completion + *

+ * For additional info + * see Evaluate Data Frame documentation + * + * @param request The {@link EvaluateDataFrameRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void evaluateDataFrameAsync(EvaluateDataFrameRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::evaluateDataFrame, + options, + EvaluateDataFrameResponse::fromXContent, + listener, + Collections.emptySet()); + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java new file mode 100644 index 0000000000000..2e3bbb170509c --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java @@ -0,0 +1,136 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
package org.elasticsearch.client.ml;

import org.elasticsearch.client.Validatable;
import org.elasticsearch.client.ValidationException;
import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
 * Request to evaluate an {@link Evaluation} against data stored in one or more indices.
 */
public class EvaluateDataFrameRequest implements ToXContentObject, Validatable {

    private static final ParseField INDEX = new ParseField("index");
    private static final ParseField EVALUATION = new ParseField("evaluation");

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<EvaluateDataFrameRequest, Void> PARSER =
        new ConstructingObjectParser<>(
            "evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List<String>) args[0], (Evaluation) args[1]));

    static {
        PARSER.declareStringArray(constructorArg(), INDEX);
        PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION);
    }

    /**
     * Parses the single-keyed {@code evaluation} object; the field name selects the
     * {@link Evaluation} implementation via the named-xcontent registry.
     */
    private static Evaluation parseEvaluation(XContentParser parser) throws IOException {
        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
        ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
        Evaluation evaluation = parser.namedObject(Evaluation.class, parser.currentName(), null);
        ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
        return evaluation;
    }

    public static EvaluateDataFrameRequest fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    private List<String> indices;
    private Evaluation evaluation;

    public EvaluateDataFrameRequest(String index, Evaluation evaluation) {
        this(Arrays.asList(index), evaluation);
    }

    public EvaluateDataFrameRequest(List<String> indices, Evaluation evaluation) {
        setIndices(indices);
        setEvaluation(evaluation);
    }

    /** @return an unmodifiable view of the indices to evaluate against */
    public List<String> getIndices() {
        return Collections.unmodifiableList(indices);
    }

    public final void setIndices(List<String> indices) {
        Objects.requireNonNull(indices);
        // Defensive copy so later mutation of the caller's list cannot affect this request.
        this.indices = new ArrayList<>(indices);
    }

    public Evaluation getEvaluation() {
        return evaluation;
    }

    public final void setEvaluation(Evaluation evaluation) {
        this.evaluation = evaluation;
    }

    @Override
    public Optional<ValidationException> validate() {
        List<String> errors = new ArrayList<>();
        if (indices.isEmpty()) {
            errors.add("At least one index must be specified");
        }
        if (evaluation == null) {
            errors.add("evaluation must not be null");
        }
        return errors.isEmpty()
            ? Optional.empty()
            : Optional.of(ValidationException.withErrors(errors));
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder
            .startObject()
            .array(INDEX.getPreferredName(), indices.toArray())
            .startObject(EVALUATION.getPreferredName())
            .field(evaluation.getName(), evaluation)
            .endObject()
            .endObject();
    }

    @Override
    public int hashCode() {
        return Objects.hash(indices, evaluation);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o;
        return Objects.equals(indices, that.indices)
            && Objects.equals(evaluation, that.evaluation);
    }
}
package org.elasticsearch.client.ml;

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.NamedObjectNotFoundException;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;

/**
 * Response of an evaluate-data-frame request: the evaluation name mapped to the
 * metric results that the client knows how to parse.
 */
public class EvaluateDataFrameResponse implements ToXContentObject {

    public static EvaluateDataFrameResponse fromXContent(XContentParser parser) throws IOException {
        if (parser.currentToken() == null) {
            parser.nextToken();
        }
        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
        ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
        String evaluationName = parser.currentName();
        parser.nextToken();
        Map<String, EvaluationMetric.Result> metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric);
        List<EvaluationMetric.Result> knownMetrics =
            metrics.values().stream()
                // Filter out null values returned by parseMetric for unrecognized metric names.
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
        return new EvaluateDataFrameResponse(evaluationName, knownMetrics);
    }

    private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException {
        String metricName = parser.currentName();
        try {
            return parser.namedObject(EvaluationMetric.Result.class, metricName, null);
        } catch (NamedObjectNotFoundException e) {
            // Metric name not recognized by this client version: consume its subtree and
            // return null so the caller can filter it out rather than fail the whole parse.
            parser.skipChildren();
            return null;
        }
    }

    private final String evaluationName;
    private final Map<String, EvaluationMetric.Result> metrics;

    public EvaluateDataFrameResponse(String evaluationName, List<EvaluationMetric.Result> metrics) {
        this.evaluationName = Objects.requireNonNull(evaluationName);
        this.metrics = Objects.requireNonNull(metrics).stream()
            .collect(Collectors.toUnmodifiableMap(EvaluationMetric.Result::getMetricName, m -> m));
    }

    public String getEvaluationName() {
        return evaluationName;
    }

    public List<EvaluationMetric.Result> getMetrics() {
        return metrics.values().stream().collect(Collectors.toList());
    }

    /**
     * Returns the result for the given metric name, cast to the caller's expected type,
     * or {@code null} if the metric is absent.
     */
    @SuppressWarnings("unchecked")
    public <T extends EvaluationMetric.Result> T getMetricByName(String metricName) {
        Objects.requireNonNull(metricName);
        return (T) metrics.get(metricName);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        return builder
            .startObject()
            .field(evaluationName, metrics)
            .endObject();
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) return true;
        if (o == null || getClass() != o.getClass()) return false;
        EvaluateDataFrameResponse that = (EvaluateDataFrameResponse) o;
        return Objects.equals(evaluationName, that.evaluationName)
            && Objects.equals(metrics, that.metrics);
    }

    @Override
    public int hashCode() {
        return Objects.hash(evaluationName, metrics);
    }

    @Override
    public final String toString() {
        return Strings.toString(this);
    }
}
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * Defines an evaluation + */ +public interface Evaluation extends ToXContentObject { + + /** + * Returns the evaluation name + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java new file mode 100644 index 0000000000000..a0f77838f1fd0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * Defines an evaluation metric + */ +public interface EvaluationMetric extends ToXContentObject { + + /** + * Returns the name of the metric + */ + String getName(); + + /** + * The result of an evaluation metric + */ + interface Result extends ToXContentObject { + + /** + * Returns the name of the metric + */ + String getMetricName(); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java new file mode 100644 index 0000000000000..764ff41de86e0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -0,0 +1,57 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
package org.elasticsearch.client.ml.dataframe.evaluation;

import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.plugins.spi.NamedXContentProvider;

import java.util.Arrays;
import java.util.List;

/**
 * Registers the named-xcontent parsers for all known evaluations, evaluation
 * metrics and metric results, so they can be resolved by name during parsing.
 */
public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {

    @Override
    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
        return Arrays.asList(
            // Evaluations
            new NamedXContentRegistry.Entry(
                Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
            // Evaluation metrics
            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
            new NamedXContentRegistry.Entry(
                EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
            // Evaluation metrics results
            new NamedXContentRegistry.Entry(
                EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
            new NamedXContentRegistry.Entry(
                EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
            new NamedXContentRegistry.Entry(
                EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
            new NamedXContentRegistry.Entry(
                EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
    }
}
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * Base class for soft-classification metrics that are computed at a list of
 * probability thresholds (the {@code at} parameter).
 */
abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {

    protected static final ParseField AT = new ParseField("at");

    // Thresholds at which the metric is evaluated; unboxed once at construction.
    protected final double[] thresholds;

    protected AbstractConfusionMatrixMetric(List<Double> at) {
        this.thresholds = Objects.requireNonNull(at).stream().mapToDouble(Double::doubleValue).toArray();
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        return builder
            .startObject()
            .field(AT.getPreferredName(), thresholds)
            .endObject();
    }
}
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
 * Area under the curve (AUC) of the receiver operating characteristic (ROC).
 * The ROC curve is a plot of the TPR (true positive rate) against
 * the FPR (false positive rate) over a varying threshold.
 */
public class AucRocMetric implements EvaluationMetric {

    public static final String NAME = "auc_roc";

    public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");

    public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
        new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0]));

    static {
        PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
    }

    public static AucRocMetric fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /** Convenience factory for a metric that also returns the ROC curve points. */
    public static AucRocMetric withCurve() {
        return new AucRocMetric(true);
    }

    private final boolean includeCurve;

    public AucRocMetric(Boolean includeCurve) {
        // A null argument (field absent in the request) defaults to false.
        this.includeCurve = includeCurve == null ? false : includeCurve;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        return builder
            .startObject()
            .field(INCLUDE_CURVE.getPreferredName(), includeCurve)
            .endObject();
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        AucRocMetric that = (AucRocMetric) o;
        return includeCurve == that.includeCurve;
    }

    @Override
    public int hashCode() {
        return Objects.hash(includeCurve);
    }

    /** Result of the AUC ROC metric: the score and, optionally, the curve points. */
    public static class Result implements EvaluationMetric.Result {

        public static Result fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        private static final ParseField SCORE = new ParseField("score");
        private static final ParseField CURVE = new ParseField("curve");

        @SuppressWarnings("unchecked")
        private static final ConstructingObjectParser<Result, Void> PARSER =
            new ConstructingObjectParser<>(
                "auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));

        static {
            PARSER.declareDouble(constructorArg(), SCORE);
            PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
        }

        private final double score;
        private final List<AucRocPoint> curve;

        public Result(double score, @Nullable List<AucRocPoint> curve) {
            this.score = score;
            this.curve = curve;
        }

        @Override
        public String getMetricName() {
            return NAME;
        }

        public double getScore() {
            return score;
        }

        /** @return the ROC curve points, or {@code null} if the curve was not requested */
        @Nullable
        public List<AucRocPoint> getCurve() {
            return curve == null ? null : Collections.unmodifiableList(curve);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            builder.startObject();
            builder.field(SCORE.getPreferredName(), score);
            if (curve != null && curve.isEmpty() == false) {
                builder.field(CURVE.getPreferredName(), curve);
            }
            builder.endObject();
            return builder;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Result that = (Result) o;
            return Objects.equals(score, that.score)
                && Objects.equals(curve, that.curve);
        }

        @Override
        public int hashCode() {
            return Objects.hash(score, curve);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }

    /** A single point of the ROC curve: TPR and FPR at a given threshold. */
    public static final class AucRocPoint implements ToXContentObject {

        public static AucRocPoint fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        private static final ParseField TPR = new ParseField("tpr");
        private static final ParseField FPR = new ParseField("fpr");
        private static final ParseField THRESHOLD = new ParseField("threshold");

        private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
            new ConstructingObjectParser<>(
                "auc_roc_point",
                true,
                args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));

        static {
            PARSER.declareDouble(constructorArg(), TPR);
            PARSER.declareDouble(constructorArg(), FPR);
            PARSER.declareDouble(constructorArg(), THRESHOLD);
        }

        private final double tpr;
        private final double fpr;
        private final double threshold;

        public AucRocPoint(double tpr, double fpr, double threshold) {
            this.tpr = tpr;
            this.fpr = fpr;
            this.threshold = threshold;
        }

        public double getTruePositiveRate() {
            return tpr;
        }

        public double getFalsePositiveRate() {
            return fpr;
        }

        public double getThreshold() {
            return threshold;
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder
                .startObject()
                .field(TPR.getPreferredName(), tpr)
                .field(FPR.getPreferredName(), fpr)
                .field(THRESHOLD.getPreferredName(), threshold)
                .endObject();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            AucRocPoint that = (AucRocPoint) o;
            return tpr == that.tpr && fpr == that.fpr && threshold == that.threshold;
        }

        @Override
        public int hashCode() {
            return Objects.hash(tpr, fpr, threshold);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }
}
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;

import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation;
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
 * Evaluation of binary soft classification methods, e.g. outlier detection.
 * This is useful to evaluate problems where a model outputs a probability of whether
 * a data frame row belongs to one of two groups.
 */
public class BinarySoftClassification implements Evaluation {

    public static final String NAME = "binary_soft_classification";

    private static final ParseField ACTUAL_FIELD = new ParseField("actual_field");
    private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field");
    private static final ParseField METRICS = new ParseField("metrics");

    @SuppressWarnings("unchecked")
    public static final ConstructingObjectParser<BinarySoftClassification, Void> PARSER =
        new ConstructingObjectParser<>(
            NAME,
            args -> new BinarySoftClassification((String) args[0], (String) args[1], (List<EvaluationMetric>) args[2]));

    static {
        PARSER.declareString(constructorArg(), ACTUAL_FIELD);
        PARSER.declareString(constructorArg(), PREDICTED_PROBABILITY_FIELD);
        PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, n, null), METRICS);
    }

    public static BinarySoftClassification fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /**
     * The field where the actual class is marked up.
     * The value of this field is assumed to either be 1 or 0, or true or false.
     */
    private final String actualField;

    /**
     * The field of the predicted probability in [0.0, 1.0].
     */
    private final String predictedProbabilityField;

    /**
     * The list of metrics to calculate
     */
    private final List<EvaluationMetric> metrics;

    public BinarySoftClassification(String actualField, String predictedProbabilityField, EvaluationMetric... metric) {
        this(actualField, predictedProbabilityField, Arrays.asList(metric));
    }

    // NOTE: the original declared the metrics parameter @Nullable while also calling
    // Objects.requireNonNull on it; the contradictory annotation is dropped here,
    // behavior (NPE on null) is unchanged.
    public BinarySoftClassification(String actualField, String predictedProbabilityField,
                                    List<EvaluationMetric> metrics) {
        this.actualField = Objects.requireNonNull(actualField);
        this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField);
        this.metrics = Objects.requireNonNull(metrics);
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(ACTUAL_FIELD.getPreferredName(), actualField);
        builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField);

        // Metrics serialize as a single object keyed by metric name.
        builder.startObject(METRICS.getPreferredName());
        for (EvaluationMetric metric : metrics) {
            builder.field(metric.getName(), metric);
        }
        builder.endObject();

        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        BinarySoftClassification that = (BinarySoftClassification) o;
        return Objects.equals(actualField, that.actualField)
            && Objects.equals(predictedProbabilityField, that.predictedProbabilityField)
            && Objects.equals(metrics, that.metrics);
    }

    @Override
    public int hashCode() {
        return Objects.hash(actualField, predictedProbabilityField, metrics);
    }
}
package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;

import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;

/**
 * Confusion matrix (TP/FP/TN/FN counts) computed at each requested threshold.
 */
public class ConfusionMatrixMetric extends AbstractConfusionMatrixMetric {

    public static final String NAME = "confusion_matrix";

    @SuppressWarnings("unchecked")
    private static final ConstructingObjectParser<ConfusionMatrixMetric, Void> PARSER =
        new ConstructingObjectParser<>(NAME, args -> new ConfusionMatrixMetric((List<Double>) args[0]));

    static {
        PARSER.declareDoubleArray(constructorArg(), AT);
    }

    public static ConfusionMatrixMetric fromXContent(XContentParser parser) {
        return PARSER.apply(parser, null);
    }

    /** Convenience factory: {@code ConfusionMatrixMetric.at(0.25, 0.5)}. */
    public static ConfusionMatrixMetric at(Double... at) {
        return new ConfusionMatrixMetric(Arrays.asList(at));
    }

    public ConfusionMatrixMetric(List<Double> at) {
        super(at);
    }

    @Override
    public String getName() {
        return NAME;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ConfusionMatrixMetric that = (ConfusionMatrixMetric) o;
        return Arrays.equals(thresholds, that.thresholds);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(thresholds);
    }

    /** Result: one {@link ConfusionMatrix} per threshold, keyed by the threshold's string form. */
    public static class Result implements EvaluationMetric.Result {

        public static Result fromXContent(XContentParser parser) throws IOException {
            return new Result(parser.map(LinkedHashMap::new, ConfusionMatrix::fromXContent));
        }

        private final Map<String, ConfusionMatrix> results;

        public Result(Map<String, ConfusionMatrix> results) {
            this.results = Objects.requireNonNull(results);
        }

        @Override
        public String getMetricName() {
            return NAME;
        }

        public ConfusionMatrix getScoreByThreshold(String threshold) {
            return results.get(threshold);
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
            return builder.map(results);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            Result that = (Result) o;
            return Objects.equals(results, that.results);
        }

        @Override
        public int hashCode() {
            return Objects.hash(results);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }

    /** The four confusion-matrix counts at a single threshold. */
    public static final class ConfusionMatrix implements ToXContentObject {

        public static ConfusionMatrix fromXContent(XContentParser parser) {
            return PARSER.apply(parser, null);
        }

        private static final ParseField TP = new ParseField("tp");
        private static final ParseField FP = new ParseField("fp");
        private static final ParseField TN = new ParseField("tn");
        private static final ParseField FN = new ParseField("fn");

        private static final ConstructingObjectParser<ConfusionMatrix, Void> PARSER =
            new ConstructingObjectParser<>(
                "confusion_matrix", true, args -> new ConfusionMatrix((long) args[0], (long) args[1], (long) args[2], (long) args[3]));

        static {
            PARSER.declareLong(constructorArg(), TP);
            PARSER.declareLong(constructorArg(), FP);
            PARSER.declareLong(constructorArg(), TN);
            PARSER.declareLong(constructorArg(), FN);
        }

        private final long tp;
        private final long fp;
        private final long tn;
        private final long fn;

        public ConfusionMatrix(long tp, long fp, long tn, long fn) {
            this.tp = tp;
            this.fp = fp;
            this.tn = tn;
            this.fn = fn;
        }

        public long getTruePositives() {
            return tp;
        }

        public long getFalsePositives() {
            return fp;
        }

        public long getTrueNegatives() {
            return tn;
        }

        public long getFalseNegatives() {
            return fn;
        }

        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder
                .startObject()
                .field(TP.getPreferredName(), tp)
                .field(FP.getPreferredName(), fp)
                .field(TN.getPreferredName(), tn)
                .field(FN.getPreferredName(), fn)
                .endObject();
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            ConfusionMatrix that = (ConfusionMatrix) o;
            return tp == that.tp && fp == that.fp && tn == that.tn && fn == that.fn;
        }

        @Override
        public int hashCode() {
            return Objects.hash(tp, fp, tn, fn);
        }

        @Override
        public String toString() {
            return Strings.toString(this);
        }
    }
}
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java new file mode 100644 index 0000000000000..2a0f1499461d6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class PrecisionMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "precision"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, args -> new PrecisionMetric((List) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } + + public static PrecisionMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static PrecisionMetric at(Double... 
at) { + return new PrecisionMetric(Arrays.asList(at)); + } + + public PrecisionMetric(List at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PrecisionMetric that = (PrecisionMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, p -> p.doubleValue())); + } + + private final Map results; + + public Result(Map results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + public Double getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java new file mode 100644 index 0000000000000..505ff1b34d7c5 --- /dev/null +++ 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class RecallMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "recall"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, args -> new RecallMetric((List) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } + 
+ public static RecallMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static RecallMetric at(Double... at) { + return new RecallMetric(Arrays.asList(at)); + } + + public RecallMetric(List at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RecallMetric that = (RecallMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, p -> p.doubleValue())); + } + + private final Map results; + + public Result(Map results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + public Double getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider index 77f1d9700d9a4..dde81e43867d8 100644 --- 
a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider +++ b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider @@ -1,3 +1,4 @@ org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider -org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider \ No newline at end of file +org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider +org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider \ No newline at end of file diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index eaeeeb59915b8..34872c548e346 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureRequestTests; import org.elasticsearch.client.ml.FlushJobRequest; @@ -82,6 +83,10 @@ import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; 
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.Detector; @@ -747,6 +752,23 @@ public void testDeleteDataFrameAnalytics() { assertNull(request.getEntity()); } + public void testEvaluateDataFrame() throws IOException { + EvaluateDataFrameRequest evaluateRequest = + new EvaluateDataFrameRequest( + Arrays.asList(generateRandomStringArray(1, 10, false, false)), + new BinarySoftClassification( + randomAlphaOfLengthBetween(1, 10), + randomAlphaOfLengthBetween(1, 10), + PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7))); + Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint()); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) { + EvaluateDataFrameRequest parsedRequest = EvaluateDataFrameRequest.fromXContent(parser); + assertThat(parsedRequest, equalTo(evaluateRequest)); + } + } + public void testPutFilter() throws IOException { MlFilter filter = MlFilterTests.createRandomBuilder("foo").build(); PutFilterRequest putFilterRequest = new PutFilterRequest(filter); @@ -918,6 +940,7 @@ protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); return new NamedXContentRegistry(namedXContent); } diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index b6224b993dd91..ede6c16e33612 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -44,6 +44,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -119,6 +121,11 @@ import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; import org.elasticsearch.client.ml.dataframe.OutlierDetection; import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -131,6 +138,7 @@ import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; +import 
org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.index.query.MatchAllQueryBuilder; @@ -154,6 +162,8 @@ import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -546,7 +556,7 @@ public void testStartDatafeed() throws Exception { String indexName = "start_data_1"; // Set up the index and docs - createIndex(indexName); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -618,7 +628,7 @@ public void testStopDatafeed() throws Exception { String indexName = "stop_data_1"; // Set up the index - createIndex(indexName); + createIndex(indexName, defaultMappingForTest()); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -680,7 +690,7 @@ public void testGetDatafeedStats() throws Exception { String indexName = "datafeed_stats_data_1"; // Set up the index - createIndex(indexName); + createIndex(indexName, defaultMappingForTest()); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -747,7 +757,7 @@ public void testPreviewDatafeed() throws Exception { String indexName = "preview_data_1"; // Set up the index and docs - createIndex(indexName); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -802,7 +812,7 @@ public void testDeleteExpiredDataGivenNothingToDelete() throws 
Exception { private String createExpiredData(String jobId) throws Exception { String indexName = jobId + "-data"; // Set up the index and docs - createIndex(indexName); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -1315,7 +1325,7 @@ public void testGetDataFrameAnalyticsConfig_ConfigNotFound() { public void testGetDataFrameAnalyticsStats() throws Exception { String sourceIndex = "get-stats-test-source-index"; String destIndex = "get-stats-test-dest-index"; - createIndex(sourceIndex); + createIndex(sourceIndex, defaultMappingForTest()); highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); @@ -1352,7 +1362,7 @@ public void testGetDataFrameAnalyticsStats() throws Exception { public void testStartDataFrameAnalyticsConfig() throws Exception { String sourceIndex = "start-test-source-index"; String destIndex = "start-test-dest-index"; - createIndex(sourceIndex); + createIndex(sourceIndex, defaultMappingForTest()); highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. 
@@ -1379,7 +1389,12 @@ public void testStartDataFrameAnalyticsConfig() throws Exception { new StartDataFrameAnalyticsRequest(configId), machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); - assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STARTED)); + assertThat( + getAnalyticsState(configId), + anyOf( + equalTo(DataFrameAnalyticsState.STARTED), + equalTo(DataFrameAnalyticsState.REINDEXING), + equalTo(DataFrameAnalyticsState.ANALYZING))); // Wait for the analytics to stop. assertBusy(() -> assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)), 30, TimeUnit.SECONDS); @@ -1444,19 +1459,116 @@ public void testDeleteDataFrameAnalyticsConfig_ConfigNotFound() { assertThat(exception.status().getStatus(), equalTo(404)); } - private void createIndex(String indexName) throws IOException { - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() + public void testEvaluateDataFrame() throws IOException { + String indexName = "evaluate-test-index"; + createIndex(indexName, mappingForClassification()); + BulkRequest bulk = new BulkRequest() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(docForClassification(indexName, false, 0.1)) // #0 + .add(docForClassification(indexName, false, 0.2)) // #1 + .add(docForClassification(indexName, false, 0.3)) // #2 + .add(docForClassification(indexName, false, 0.4)) // #3 + .add(docForClassification(indexName, false, 0.7)) // #4 + .add(docForClassification(indexName, true, 0.2)) // #5 + .add(docForClassification(indexName, true, 0.3)) // #6 + .add(docForClassification(indexName, true, 0.4)) // #7 + .add(docForClassification(indexName, true, 0.8)) // #8 + .add(docForClassification(indexName, true, 0.9)); // #9 + highLevelClient().bulk(bulk, RequestOptions.DEFAULT); + + 
MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + new BinarySoftClassification( + actualField, + probabilityField, + PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve())); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4)); + + PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME); + assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME)); + // Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9)); + // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.5"), closeTo(0.666666666, 1e-9)); + // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9)); + assertNull(precisionResult.getScoreByThreshold("0.1")); + + RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME); + assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME)); + // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) + assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9)); + // Recall is 2/5=0.4 as there were 2 true positive examples (#8, 
#9) among the 5 true examples (#5, #6, #7, #8, #9) + assertThat(recallResult.getScoreByThreshold("0.7"), closeTo(0.4, 1e-9)); + assertNull(recallResult.getScoreByThreshold("0.1")); + + ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME); + assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME)); + ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); + assertThat(confusionMatrix.getTruePositives(), equalTo(2L)); // docs #8 and #9 + assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4 + assertThat(confusionMatrix.getTrueNegatives(), equalTo(4L)); // docs #0, #1, #2 and #3 + assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7 + assertNull(confusionMatrixResult.getScoreByThreshold("0.1")); + + AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); + assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); + assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9)); + assertNotNull(aucRocResult.getCurve()); + List curve = aucRocResult.getCurve(); + AucRocMetric.AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get(); + assertThat(curvePointAtThreshold0.getTruePositiveRate(), equalTo(1.0)); + assertThat(curvePointAtThreshold0.getFalsePositiveRate(), equalTo(1.0)); + assertThat(curvePointAtThreshold0.getThreshold(), equalTo(0.0)); + AucRocMetric.AucRocPoint curvePointAtThreshold1 = curve.stream().filter(p -> p.getThreshold() == 1.0).findFirst().get(); + assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); + assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); + assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); + } + + private static XContentBuilder defaultMappingForTest() throws IOException { + return 
XContentFactory.jsonBuilder().startObject() .startObject("properties") - .startObject("timestamp") + .startObject("timestamp") .field("type", "date") .endObject() .startObject("total") .field("type", "long") .endObject() .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + .endObject(); + } + + private static final String actualField = "label"; + private static final String probabilityField = "p"; + + private static XContentBuilder mappingForClassification() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(actualField) + .field("type", "keyword") + .endObject() + .startObject(probabilityField) + .field("type", "double") + .endObject() + .endObject() + .endObject(); + } + + private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) { + return new IndexRequest() + .index(indexName) + .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p); + } + + private void createIndex(String indexName, XContentBuilder mapping) throws IOException { + highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT); } public void testPutFilter() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 6ca2e3c2bdd24..87553c194e76d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -60,6 +60,11 @@ import org.elasticsearch.client.indexlifecycle.UnfollowAction; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import 
org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -113,6 +118,7 @@ import static org.hamcrest.CoreMatchers.endsWith; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.hasItems; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -668,7 +674,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(22, namedXContents.size()); + assertEquals(31, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -678,7 +684,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 6, categories.size()); + assertEquals("Had: " + categories, 9, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -706,6 +712,12 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); assertEquals(Integer.valueOf(1), 
categories.get(SyncConfig.class)); assertTrue(names.contains(TimeSyncConfig.NAME)); + assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); + assertThat(names, hasItems(BinarySoftClassification.NAME)); + assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); + assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java new file mode 100644 index 0000000000000..825adcd2060f8 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class AucRocMetricAucRocPointTests extends AbstractXContentTestCase { + + static AucRocMetric.AucRocPoint randomPoint() { + return new AucRocMetric.AucRocPoint(randomDouble(), randomDouble(), randomDouble()); + } + + @Override + protected AucRocMetric.AucRocPoint createTestInstance() { + return randomPoint(); + } + + @Override + protected AucRocMetric.AucRocPoint doParseInstance(XContentParser parser) throws IOException { + return AucRocMetric.AucRocPoint.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java new file mode 100644 index 0000000000000..9ea7689d60f32 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.client.ml.AucRocMetricAucRocPointTests.randomPoint; + +public class AucRocMetricResultTests extends AbstractXContentTestCase { + + static AucRocMetric.Result randomResult() { + return new AucRocMetric.Result( + randomDouble(), + Stream + .generate(() -> randomPoint()) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList())); + } + + @Override + protected AucRocMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected AucRocMetric.Result doParseInstance(XContentParser parser) throws IOException { + return AucRocMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java new file mode 100644 index 0000000000000..28eb221b318c6 --- 
/dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ConfusionMatrixMetricConfusionMatrixTests extends AbstractXContentTestCase<ConfusionMatrixMetric.ConfusionMatrix> { + + static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() { + return new ConfusionMatrixMetric.ConfusionMatrix(randomInt(), randomInt(), randomInt(), randomInt()); + } + + @Override + protected ConfusionMatrixMetric.ConfusionMatrix createTestInstance() { + return randomConfusionMatrix(); + } + + @Override + protected ConfusionMatrixMetric.ConfusionMatrix doParseInstance(XContentParser parser) throws IOException { + return ConfusionMatrixMetric.ConfusionMatrix.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java new file mode 100644 index 0000000000000..c4b299a96b536 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java @@ -0,0 +1,62 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.client.ml.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix; + +public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase { + + static ConfusionMatrixMetric.Result randomResult() { + return new ConfusionMatrixMetric.Result( + Stream + .generate(() -> randomConfusionMatrix()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected ConfusionMatrixMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected ConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException { + return ConfusionMatrixMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java new file mode 100644 index 0000000000000..b41d113686ccf --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -0,0 +1,76 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; + +public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase<EvaluateDataFrameResponse> { + + public static EvaluateDataFrameResponse randomResponse() { + List<EvaluationMetric.Result> metrics = new ArrayList<>(); + if (randomBoolean()) { + metrics.add(AucRocMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(PrecisionMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(RecallMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(ConfusionMatrixMetricResultTests.randomResult()); + } + return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics); + } + + @Override + protected EvaluateDataFrameResponse createTestInstance() { + return randomResponse(); + } + + @Override + protected EvaluateDataFrameResponse 
doParseInstance(XContentParser parser) throws IOException { + return EvaluateDataFrameResponse.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the metrics map (i.e. alongside named metrics like "precision" or "recall") + return field -> field.isEmpty() || field.contains("."); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java new file mode 100644 index 0000000000000..607adacebb827 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java @@ -0,0 +1,60 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class PrecisionMetricResultTests extends AbstractXContentTestCase { + + static PrecisionMetric.Result randomResult() { + return new PrecisionMetric.Result( + Stream + .generate(() -> randomDouble()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected PrecisionMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected PrecisionMetric.Result doParseInstance(XContentParser parser) throws IOException { + return PrecisionMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java new file mode 100644 index 0000000000000..138875007e30d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java @@ -0,0 +1,60 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class RecallMetricResultTests extends AbstractXContentTestCase { + + static RecallMetric.Result randomResult() { + return new RecallMetric.Result( + Stream + .generate(() -> randomDouble()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected RecallMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected RecallMetric.Result doParseInstance(XContentParser parser) throws IOException { + return RecallMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} From bec3b83f306472829912ad12b03c20685fb33fee Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 31 May 2019 19:20:26 +0300 Subject: [PATCH 
55/67] [FEATURE][ML] Add allow_no_match to df-analytics get and get-stats APIs (#42752) --- .../client/MLRequestConverters.java | 7 ++ .../ml/GetDataFrameAnalyticsRequest.java | 22 ++++++- .../ml/GetDataFrameAnalyticsStatsRequest.java | 22 ++++++- .../client/MLRequestConvertersTests.java | 6 +- .../action/GetDataFrameAnalyticsAction.java | 2 + .../GetDataFrameAnalyticsStatsAction.java | 18 +++++- .../ml/qa/ml-with-security/build.gradle | 5 ++ ...sportGetDataFrameAnalyticsStatsAction.java | 1 + .../RestGetDataFrameAnalyticsAction.java | 3 +- .../RestGetDataFrameAnalyticsStatsAction.java | 2 + .../api/ml.get_data_frame_analytics.json | 12 +++- .../ml.get_data_frame_analytics_stats.json | 12 +++- .../test/ml/data_frame_analytics_crud.yml | 64 +++++++++++++++++++ 13 files changed, 165 insertions(+), 11 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 11c7c1a4cc40e..9d195fc0a530c 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -613,6 +613,9 @@ static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); } } + if (getRequest.getAllowNoMatch() != null) { + params.putParam(GetDataFrameAnalyticsRequest.ALLOW_NO_MATCH.getPreferredName(), Boolean.toString(getRequest.getAllowNoMatch())); + } request.addParameters(params.asMap()); return request; } @@ -634,6 +637,10 @@ static Request getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest getS params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); } } + if (getStatsRequest.getAllowNoMatch() != null) { + params.putParam(GetDataFrameAnalyticsStatsRequest.ALLOW_NO_MATCH.getPreferredName(), + 
Boolean.toString(getStatsRequest.getAllowNoMatch())); + } request.addParameters(params.asMap()); return request; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java index d72742ea3377b..40698c4b528fa 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java @@ -23,6 +23,7 @@ import org.elasticsearch.client.ValidationException; import org.elasticsearch.client.core.PageParams; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; import java.util.Arrays; import java.util.List; @@ -31,7 +32,10 @@ public class GetDataFrameAnalyticsRequest implements Validatable { + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + private final List ids; + private Boolean allowNoMatch; private PageParams pageParams; /** @@ -50,6 +54,21 @@ public List getIds() { return ids; } + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no data frame analytics. 
+ * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any data frame analytics + */ + public GetDataFrameAnalyticsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + public PageParams getPageParams() { return pageParams; } @@ -74,11 +93,12 @@ public boolean equals(Object o) { GetDataFrameAnalyticsRequest other = (GetDataFrameAnalyticsRequest) o; return Objects.equals(ids, other.ids) + && Objects.equals(allowNoMatch, other.allowNoMatch) && Objects.equals(pageParams, other.pageParams); } @Override public int hashCode() { - return Objects.hash(ids, pageParams); + return Objects.hash(ids, allowNoMatch, pageParams); } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java index 84bef6894213e..f1e4a35fb661b 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java @@ -23,6 +23,7 @@ import org.elasticsearch.client.ValidationException; import org.elasticsearch.client.core.PageParams; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; import java.util.Arrays; import java.util.List; @@ -34,7 +35,10 @@ */ public class GetDataFrameAnalyticsStatsRequest implements Validatable { + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + private final List ids; + private Boolean allowNoMatch; private PageParams pageParams; public GetDataFrameAnalyticsStatsRequest(String... 
ids) { @@ -45,6 +49,21 @@ public List getIds() { return ids; } + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no data frame analytics. + * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any data frame analytics + */ + public GetDataFrameAnalyticsStatsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + public PageParams getPageParams() { return pageParams; } @@ -69,11 +88,12 @@ public boolean equals(Object o) { GetDataFrameAnalyticsStatsRequest other = (GetDataFrameAnalyticsStatsRequest) o; return Objects.equals(ids, other.ids) + && Objects.equals(allowNoMatch, other.allowNoMatch) && Objects.equals(pageParams, other.pageParams); } @Override public int hashCode() { - return Objects.hash(ids, pageParams); + return Objects.hash(ids, allowNoMatch, pageParams); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 34872c548e346..2a478718b6ffe 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -703,12 +703,13 @@ public void testGetDataFrameAnalytics() { String configId2 = randomAlphaOfLength(10); String configId3 = randomAlphaOfLength(10); GetDataFrameAnalyticsRequest getRequest = new GetDataFrameAnalyticsRequest(configId1, configId2, configId3) + .setAllowNoMatch(false) .setPageParams(new PageParams(100, 300)); Request request = MLRequestConverters.getDataFrameAnalytics(getRequest); assertEquals(HttpGet.METHOD_NAME, request.getMethod()); assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3, request.getEndpoint()); - 
assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"))); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false"))); assertNull(request.getEntity()); } @@ -717,12 +718,13 @@ public void testGetDataFrameAnalyticsStats() { String configId2 = randomAlphaOfLength(10); String configId3 = randomAlphaOfLength(10); GetDataFrameAnalyticsStatsRequest getStatsRequest = new GetDataFrameAnalyticsStatsRequest(configId1, configId2, configId3) + .setAllowNoMatch(false) .setPageParams(new PageParams(100, 300)); Request request = MLRequestConverters.getDataFrameAnalyticsStats(getStatsRequest); assertEquals(HttpGet.METHOD_NAME, request.getMethod()); assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3 + "/_stats", request.getEndpoint()); - assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"))); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false"))); assertNull(request.getEntity()); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java index b689da03ee4f0..92233fbb27692 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java @@ -34,6 +34,8 @@ public Response newResponse() { public static class Request extends AbstractGetResourcesRequest { + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + public Request() { setAllowNoResources(true); } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java index 5b19da2766129..b14feaa8839f5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -55,7 +56,10 @@ public Writeable.Reader getResponseReader() { public static class Request extends BaseTasksRequest { + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + private String id; + private boolean allowNoMatch = true; private PageParams pageParams = PageParams.defaultParams(); // Used internally to store the expanded IDs @@ -71,6 +75,7 @@ public Request() {} public Request(StreamInput in) throws IOException { super(in); id = in.readString(); + allowNoMatch = in.readBoolean(); pageParams = in.readOptionalWriteable(PageParams::new); expandedIds = in.readStringList(); } @@ -87,6 +92,7 @@ public List getExpandedIds() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(id); + out.writeBoolean(allowNoMatch); out.writeOptionalWriteable(pageParams); out.writeStringCollection(expandedIds); } @@ -99,6 +105,14 @@ public String getId() { return id; } + public boolean isAllowNoMatch() { + return allowNoMatch; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + public void 
setPageParams(PageParams pageParams) { this.pageParams = pageParams; } @@ -119,7 +133,7 @@ public ActionRequestValidationException validate() { @Override public int hashCode() { - return Objects.hash(id, pageParams); + return Objects.hash(id, allowNoMatch, pageParams); } @Override @@ -131,7 +145,7 @@ public boolean equals(Object obj) { return false; } Request other = (Request) obj; - return Objects.equals(id, other.id) && Objects.equals(pageParams, other.pageParams); + return Objects.equals(id, other.id) && allowNoMatch == other.allowNoMatch && Objects.equals(pageParams, other.pageParams); } } diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index d73f5a71de3e2..e85603b6aa89b 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -52,6 +52,11 @@ integTestRunner { 'ml/data_frame_analytics_crud/Test put config given missing analysis', 'ml/data_frame_analytics_crud/Test put config given empty analysis', 'ml/data_frame_analytics_crud/Test get given missing analytics', + 'ml/data_frame_analytics_crud/Test get given missing analytics and allow_no_match is false', + 'ml/data_frame_analytics_crud/Test get given expression without matches and allow_no_match is false', + 'ml/data_frame_analytics_crud/Test get stats given missing analytics', + 'ml/data_frame_analytics_crud/Test get stats given missing analytics and allow_no_match is false', + 'ml/data_frame_analytics_crud/Test get stats given expression without matches and allow_no_match is false', 'ml/data_frame_analytics_crud/Test delete given missing config', 'ml/data_frame_analytics_crud/Test max model memory limit', 'ml/evaluate_data_frame/Test given missing index', diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java index ec6c9371ea405..575069e4fd4dc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java @@ -148,6 +148,7 @@ protected void doExecute(Task task, GetDataFrameAnalyticsStatsAction.Request req GetDataFrameAnalyticsAction.Request getRequest = new GetDataFrameAnalyticsAction.Request(); getRequest.setResourceId(request.getId()); + getRequest.setAllowNoResources(request.isAllowNoMatch()); getRequest.setPageParams(request.getPageParams()); executeAsyncWithOrigin(client, ML_ORIGIN, GetDataFrameAnalyticsAction.INSTANCE, getRequest, getResponseListener); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java index 938694065a7ea..b37ff2b7e5982 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsAction.java @@ -44,7 +44,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); } - + request.setAllowNoResources(restRequest.paramAsBoolean(GetDataFrameAnalyticsAction.Request.ALLOW_NO_MATCH.getPreferredName(), + request.isAllowNoResources())); return channel -> client.execute(GetDataFrameAnalyticsAction.INSTANCE, request, new RestToXContentListener<>(channel)); } } diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java index 8f1781ba75fc3..3c363762817ba 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/dataframe/RestGetDataFrameAnalyticsStatsAction.java @@ -44,6 +44,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient request.setPageParams(new PageParams(restRequest.paramAsInt(PageParams.FROM.getPreferredName(), PageParams.DEFAULT_FROM), restRequest.paramAsInt(PageParams.SIZE.getPreferredName(), PageParams.DEFAULT_SIZE))); } + request.setAllowNoMatch(restRequest.paramAsBoolean(GetDataFrameAnalyticsStatsAction.Request.ALLOW_NO_MATCH.getPreferredName(), + request.isAllowNoMatch())); return channel -> client.execute(GetDataFrameAnalyticsStatsAction.INSTANCE, request, new RestToXContentListener<>(channel)); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json index 9c65661d254a4..491ced8b541da 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json @@ -14,13 +14,21 @@ } }, "params": { + "allow_no_match": { + "type": "boolean", + "required": false, + "description": "Whether to ignore if a wildcard expression matches no data frame analytics. 
(This includes `_all` string or when no data frame analytics have been specified)", + "default": true + }, "from": { "type": "int", - "description": "skips a number of analytics" + "description": "skips a number of analytics", + "default": 0 }, "size": { "type": "int", - "description": "specifies a max number of analytics to get" + "description": "specifies a max number of analytics to get", + "default": 100 } } }, diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json index d74f5880c72de..87ffe6c0fd722 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json @@ -14,13 +14,21 @@ } }, "params": { + "allow_no_match": { + "type": "boolean", + "required": false, + "description": "Whether to ignore if a wildcard expression matches no data frame analytics. 
(This includes `_all` string or when no data frame analytics have been specified)", + "default": true + }, "from": { "type": "int", - "description": "skips a number of analytics" + "description": "skips a number of analytics", + "default": 0 }, "size": { "type": "int", - "description": "specifies a max number of analytics to get" + "description": "specifies a max number of analytics to get", + "default": 100 } } }, diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 5dc265a74da22..ceb8cd04a6f08 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -427,6 +427,70 @@ ml.get_data_frame_analytics: id: "missing-analytics" +--- +"Test get given missing analytics and allow_no_match is false": + + - do: + catch: missing + ml.get_data_frame_analytics: + id: "missing-analytics" + allow_no_match: false + +--- +"Test get given expression without matches and allow_no_match is false": + + - do: + catch: missing + ml.get_data_frame_analytics: + id: "missing-analytics*" + allow_no_match: false + +--- +"Test get given expression without matches and allow_no_match is true": + + - do: + ml.get_data_frame_analytics: + id: "missing-analytics*" + allow_no_match: true + - match: { count: 0 } + - match: { data_frame_analytics: [] } + +--- +"Test get stats given missing analytics": + + - do: + catch: missing + ml.get_data_frame_analytics_stats: + id: "missing-analytics" + +--- +"Test get stats given missing analytics and allow_no_match is false": + + - do: + catch: missing + ml.get_data_frame_analytics_stats: + id: "missing-analytics" + allow_no_match: false + +--- +"Test get stats given expression without matches and allow_no_match is false": + + - do: + catch: missing + ml.get_data_frame_analytics_stats: + id: 
"missing-analytics*" + allow_no_match: false + +--- +"Test get stats given expression without matches and allow_no_match is true": + + - do: + ml.get_data_frame_analytics_stats: + id: "missing-analytics*" + allow_no_match: true + - match: { count: 0 } + - match: { data_frame_analytics: [] } + --- "Test get stats given multiple analytics": From dc017a4e2381bfb6fe742df8495bf41c92aaf4f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Fri, 31 May 2019 19:20:27 +0200 Subject: [PATCH 56/67] [FEATURE][ML] Client-side "Stop" API method for DataFrame analytics config. (#42756) Implement HLRC for Data Frame Analytics Stop API --- .../client/MLRequestConverters.java | 20 +++++ .../client/MachineLearningClient.java | 44 +++++++++- .../ml/StartDataFrameAnalyticsRequest.java | 5 -- .../ml/StopDataFrameAnalyticsRequest.java | 88 +++++++++++++++++++ .../ml/StopDataFrameAnalyticsResponse.java | 87 ++++++++++++++++++ .../client/MLRequestConvertersTests.java | 20 +++++ .../client/MachineLearningIT.java | 41 +++++++++ .../StartDataFrameAnalyticsRequestTests.java | 8 +- .../StopDataFrameAnalyticsRequestTests.java | 43 +++++++++ .../StopDataFrameAnalyticsResponseTests.java | 42 +++++++++ .../action/StopDataFrameAnalyticsAction.java | 12 ++- 11 files changed, 393 insertions(+), 17 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index 
9d195fc0a530c..e5a98b4632432 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -73,6 +73,7 @@ import org.elasticsearch.client.ml.SetUpgradeModeRequest; import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.UpdateDatafeedRequest; import org.elasticsearch.client.ml.UpdateFilterRequest; @@ -660,6 +661,25 @@ static Request startDataFrameAnalytics(StartDataFrameAnalyticsRequest startReque return request; } + static Request stopDataFrameAnalytics(StopDataFrameAnalyticsRequest stopRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(stopRequest.getId()) + .addPathPartAsIs("_stop") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (stopRequest.getTimeout() != null) { + params.withTimeout(stopRequest.getTimeout()); + } + if (stopRequest.getAllowNoMatch() != null) { + params.putParam( + StopDataFrameAnalyticsRequest.ALLOW_NO_MATCH.getPreferredName(), Boolean.toString(stopRequest.getAllowNoMatch())); + } + request.addParameters(params.asMap()); + return request; + } + static Request deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest deleteRequest) { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml", "data_frame", "analytics") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 002ebd1c45bfc..ea72c355a02e7 100644 --- 
a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -99,6 +99,8 @@ import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -2030,7 +2032,7 @@ public AcknowledgedResponse startDataFrameAnalytics(StartDataFrameAnalyticsReque Collections.emptySet()); } - /** + /** * Starts Data Frame Analytics asynchronously and notifies listener upon completion *

* For additional info @@ -2050,6 +2052,46 @@ public void startDataFrameAnalyticsAsync(StartDataFrameAnalyticsRequest request, Collections.emptySet()); } + /** + * Stops Data Frame Analytics + *

+ * For additional info + * see Stop Data Frame Analytics documentation + * + * @param request The {@link StopDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link StopDataFrameAnalyticsResponse} + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public StopDataFrameAnalyticsResponse stopDataFrameAnalytics(StopDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::stopDataFrameAnalytics, + options, + StopDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Stops Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see Stop Data Frame Analytics documentation + * + * @param request The {@link StopDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void stopDataFrameAnalyticsAsync(StopDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::stopDataFrameAnalytics, + options, + StopDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + /** * Deletes the given Data Frame Analytics config *

diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java index 16e43180d57fe..68a925d15019a 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java @@ -33,12 +33,7 @@ public class StartDataFrameAnalyticsRequest implements Validatable { private TimeValue timeout; public StartDataFrameAnalyticsRequest(String id) { - this(id, null); - } - - public StartDataFrameAnalyticsRequest(String id, @Nullable TimeValue timeout) { this.id = id; - this.timeout = timeout; } public String getId() { diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..9608d40fc7d16 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java @@ -0,0 +1,88 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; + +import java.util.Objects; +import java.util.Optional; + +public class StopDataFrameAnalyticsRequest implements Validatable { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private final String id; + private TimeValue timeout; + private Boolean allowNoMatch; + + public StopDataFrameAnalyticsRequest(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public TimeValue getTimeout() { + return timeout; + } + + public StopDataFrameAnalyticsRequest setTimeout(@Nullable TimeValue timeout) { + this.timeout = timeout; + return this; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + public StopDataFrameAnalyticsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + @Override + public Optional validate() { + if (id == null) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StopDataFrameAnalyticsRequest other = (StopDataFrameAnalyticsRequest) o; + return Objects.equals(id, other.id) + && Objects.equals(timeout, other.timeout) + && Objects.equals(allowNoMatch, other.allowNoMatch); + } + + @Override + public int hashCode() { + return Objects.hash(id, timeout, allowNoMatch); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java new file mode 100644 index 0000000000000..5f45c6f9ea51f --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java @@ -0,0 +1,87 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +/** + * Response indicating if the Machine Learning Data Frame Analytics is now stopped or not + */ +public class StopDataFrameAnalyticsResponse implements ToXContentObject { + + private static final ParseField STOPPED = new ParseField("stopped"); + + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "stop_data_frame_analytics_response", + true, + args -> new StopDataFrameAnalyticsResponse((Boolean) args[0])); + + static { + PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), STOPPED); + } + + public static StopDataFrameAnalyticsResponse fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private final boolean stopped; + + public StopDataFrameAnalyticsResponse(boolean stopped) { + this.stopped = stopped; + } + + /** + * Has the Data Frame Analytics stopped or not + * + * @return boolean value indicating the Data Frame Analytics stopped status + */ + public boolean isStopped() { + return stopped; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StopDataFrameAnalyticsResponse other = (StopDataFrameAnalyticsResponse) o; + return stopped == other.stopped; + } + + @Override + public int hashCode() { + return Objects.hash(stopped); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(STOPPED.getPreferredName(), stopped) + .endObject(); + } +} diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index 2a478718b6ffe..36d71df5f91bb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -71,6 +71,7 @@ import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedRequestTests; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.UpdateFilterRequest; import org.elasticsearch.client.ml.UpdateJobRequest; @@ -746,6 +747,25 @@ public void testStartDataFrameAnalytics_WithTimeout() { assertNull(request.getEntity()); } + public void testStopDataFrameAnalytics() { + StopDataFrameAnalyticsRequest stopRequest = new StopDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.stopDataFrameAnalytics(stopRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + stopRequest.getId() + "/_stop", request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testStopDataFrameAnalytics_WithParams() { + StopDataFrameAnalyticsRequest stopRequest = new StopDataFrameAnalyticsRequest(randomAlphaOfLength(10)) + .setTimeout(TimeValue.timeValueMinutes(1)) + .setAllowNoMatch(false); + Request request = MLRequestConverters.stopDataFrameAnalytics(stopRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + stopRequest.getId() + "/_stop", request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("timeout", "1m"), hasEntry("allow_no_match", "false"))); + 
assertNull(request.getEntity()); + } + public void testDeleteDataFrameAnalytics() { DeleteDataFrameAnalyticsRequest deleteRequest = new DeleteDataFrameAnalyticsRequest(randomAlphaOfLength(10)); Request request = MLRequestConverters.deleteDataFrameAnalytics(deleteRequest); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index ede6c16e33612..dab8040be18ab 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -99,6 +99,8 @@ import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -1403,6 +1405,45 @@ public void testStartDataFrameAnalyticsConfig() throws Exception { assertTrue(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); } + public void testStopDataFrameAnalyticsConfig() throws Exception { + String sourceIndex = "stop-test-source-index"; + String destIndex = "stop-test-dest-index"; + createIndex(sourceIndex); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + + // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. 
+ assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "stop-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + + AcknowledgedResponse startDataFrameAnalyticsResponse = execute( + new StartDataFrameAnalyticsRequest(configId), + machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); + assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STARTED)); + + StopDataFrameAnalyticsResponse stopDataFrameAnalyticsResponse = execute( + new StopDataFrameAnalyticsRequest(configId), + machineLearningClient::stopDataFrameAnalytics, machineLearningClient::stopDataFrameAnalyticsAsync); + assertTrue(stopDataFrameAnalyticsResponse.isStopped()); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + } + private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOException { MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); GetDataFrameAnalyticsStatsResponse statsResponse = diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java index 
97367730561cf..6e43b50bcd12b 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java @@ -30,14 +30,14 @@ public class StartDataFrameAnalyticsRequestTests extends ESTestCase { public void testValidate_Ok() { assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").validate()); - assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo", null).validate()); - assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo", TimeValue.ZERO).validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").setTimeout(null).validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").setTimeout(TimeValue.ZERO).validate()); } public void testValidate_Failure() { - assertThat(new StartDataFrameAnalyticsRequest(null, null).validate().get().getMessage(), + assertThat(new StartDataFrameAnalyticsRequest(null).validate().get().getMessage(), containsString("data frame analytics id must not be null")); - assertThat(new StartDataFrameAnalyticsRequest(null, TimeValue.ZERO).validate().get().getMessage(), + assertThat(new StartDataFrameAnalyticsRequest(null).setTimeout(TimeValue.ZERO).validate().get().getMessage(), containsString("data frame analytics id must not be null")); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..57af2083743ae --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class StopDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").validate()); + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").setTimeout(null).validate()); + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").setTimeout(TimeValue.ZERO).validate()); + } + + public void testValidate_Failure() { + assertThat(new StopDataFrameAnalyticsRequest(null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + assertThat(new StopDataFrameAnalyticsRequest(null).setTimeout(TimeValue.ZERO).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java new file mode 100644 index 
0000000000000..55ef1aed7534a --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java @@ -0,0 +1,42 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class StopDataFrameAnalyticsResponseTests extends AbstractXContentTestCase { + + @Override + protected StopDataFrameAnalyticsResponse createTestInstance() { + return new StopDataFrameAnalyticsResponse(randomBoolean()); + } + + @Override + protected StopDataFrameAnalyticsResponse doParseInstance(XContentParser parser) throws IOException { + return StopDataFrameAnalyticsResponse.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java index 69aa60501fdac..43d382147fd64 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java @@ -132,13 +132,11 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (id != null) { - builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); - } - builder.field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch); - builder.endObject(); - return builder; + return builder + .startObject() + .field(DataFrameAnalyticsConfig.ID.getPreferredName(), id) + .field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch) + .endObject(); } @Override From 7d43d6ecf31f0f555ea9234d5bf3803bcd8296a3 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 3 Jun 2019 10:16:45 +0300 Subject: [PATCH 57/67] [FEATURE][ML] Ignore `index.version.upgraded` index setting (#42768) When a data frame analytics job reindexes the source index into the dest index, we try to preserve settings. However, some settings are internal and are not valid during index creation. This commit adds `index.version.upgraded` to the list of ignored settings. 
--- .../xpack/ml/dataframe/DataFrameAnalyticsManager.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java index c0f394b6dc890..37bded7c3c712 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -60,7 +60,8 @@ public class DataFrameAnalyticsManager { "index.creation_date", "index.provided_name", "index.uuid", - "index.version.created" + "index.version.created", + "index.version.upgraded" ); private final ClusterService clusterService; From 69f5652881a4537a0db6d9a6a987acc53db76a58 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 4 Jun 2019 18:11:09 +0300 Subject: [PATCH 58/67] [ML] Fix compilation error from upstream --- .../test/java/org/elasticsearch/client/MachineLearningIT.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index dab8040be18ab..a5055b7e4f05d 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1408,7 +1408,7 @@ public void testStartDataFrameAnalyticsConfig() throws Exception { public void testStopDataFrameAnalyticsConfig() throws Exception { String sourceIndex = "stop-test-source-index"; String destIndex = "stop-test-dest-index"; - createIndex(sourceIndex); + createIndex(sourceIndex, mappingForClassification()); highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), 
RequestOptions.DEFAULT); // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. From 44856a51d716d71d080826388a6ea4f67814d386 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Wed, 5 Jun 2019 17:20:31 +0300 Subject: [PATCH 59/67] =?UTF-8?q?[FEATURE][ML]=20auc=5Froc=20cannot=20be?= =?UTF-8?q?=20calculated=20when=20there=20are=20no=20inliers/=E2=80=A6=20(?= =?UTF-8?q?#42853)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also fixes a bug with the matching query for binary soft classification --- .../evaluation/softclassification/AucRoc.java | 16 +++-- .../BinarySoftClassification.java | 2 +- .../ml/qa/ml-with-security/build.gradle | 2 + .../TransportEvaluateDataFrameAction.java | 8 ++- .../test/ml/evaluate_data_frame.yml | 70 ++++++++++++++++--- 5 files changed, 84 insertions(+), 14 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java index 9125a50cd85c5..228dac00bfb68 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -21,6 +21,7 @@ import org.elasticsearch.search.aggregations.bucket.filter.Filter; import org.elasticsearch.search.aggregations.metrics.Percentiles; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.ArrayList; @@ -145,16 +146,23 @@ private String restLabelsAggName(ClassInfo classInfo) { public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { Filter classAgg = 
aggs.get(evaluatedLabelAggName(classInfo)); Filter restAgg = aggs.get(restLabelsAggName(classInfo)); - double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES)); - double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES)); + double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES), + "[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]"); + double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES), + "[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]"); List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); double aucRocScore = calculateAucScore(aucRocCurve); return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); } - private static double[] percentilesArray(Percentiles percentiles) { + private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) { double[] result = new double[99]; - percentiles.forEach(percentile -> result[((int) percentile.getPercent()) - 1] = percentile.getValue()); + percentiles.forEach(percentile -> { + if (Double.isNaN(percentile.getValue())) { + throw ExceptionsHelper.badRequestException(errorIfUndefined); + } + result[((int) percentile.getPercent()) - 1] = percentile.getValue(); + }); return result; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java index 732fcdfb44eae..f594e7598fc20 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -192,7 +192,7 @@ public void evaluate(SearchResponse searchResponse, ActionListener threadPool.generic().execute(() -> evaluation.evaluate(searchResponse, resultsListener)), + searchResponse -> threadPool.generic().execute(() -> { + try { + evaluation.evaluate(searchResponse, resultsListener); + } catch (Exception e) { + listener.onFailure(e); + }; + }), listener::onFailure )); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml index 3bbb59c205fb9..6c41edeb4026d 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml @@ -7,7 +7,9 @@ setup: { "is_outlier": false, "is_outlier_int": 0, - "outlier_score": 0.0 + "outlier_score": 0.0, + "all_true_field": true, + "all_false_field": false } - do: @@ -17,7 +19,9 @@ setup: { "is_outlier": false, "is_outlier_int": 0, - "outlier_score": 0.2 + "outlier_score": 0.2, + "all_true_field": true, + "all_false_field": false } - do: @@ -27,7 +31,9 @@ setup: { "is_outlier": false, "is_outlier_int": 0, - "outlier_score": 0.3 + "outlier_score": 0.3, + "all_true_field": true, + "all_false_field": false } - do: @@ -37,7 +43,9 @@ setup: { "is_outlier": true, "is_outlier_int": 1, - "outlier_score": 0.3 + "outlier_score": 0.3, + "all_true_field": true, + "all_false_field": false } - do: @@ -47,7 +55,9 @@ setup: { "is_outlier": true, "is_outlier_int": 1, - "outlier_score": 0.4 + "outlier_score": 0.4, + "all_true_field": true, + "all_false_field": false } - do: @@ -57,7 +67,9 @@ setup: { "is_outlier": true, "is_outlier_int": 1, - "outlier_score": 0.5 + "outlier_score": 0.5, + "all_true_field": true, + "all_false_field": false } - do: @@ -67,7 +79,9 @@ setup: { 
"is_outlier": true, "is_outlier_int": 1, - "outlier_score": 0.9 + "outlier_score": 0.9, + "all_true_field": true, + "all_false_field": false } - do: @@ -77,7 +91,9 @@ setup: { "is_outlier": true, "is_outlier_int": 1, - "outlier_score": 0.95 + "outlier_score": 0.95, + "all_true_field": true, + "all_false_field": false } # This document misses the required fields and should be ignored @@ -152,6 +168,44 @@ setup: - match: { binary_soft_classification.auc_roc.score: 0.9899 } - is_true: binary_soft_classification.auc_roc.curve +--- +"Test binary_soft_classifition auc_roc given actual_field is always true": + - do: + catch: /\[auc_roc\] requires at least one actual_field to have a different value than \[true\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "all_true_field", + "predicted_probability_field": "outlier_score", + "metrics": { + "auc_roc": {} + } + } + } + } + +--- +"Test binary_soft_classifition auc_roc given actual_field is always false": + - do: + catch: /\[auc_roc\] requires at least one actual_field to have the value \[true\]/ + ml.evaluate_data_frame: + body: > + { + "index": "utopia", + "evaluation": { + "binary_soft_classification": { + "actual_field": "all_false_field", + "predicted_probability_field": "outlier_score", + "metrics": { + "auc_roc": {} + } + } + } + } + --- "Test binary_soft_classifition precision": - do: From a2268a27e1c72c0fc89d837c3451a0276abcd8be Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 7 Jun 2019 13:36:38 +0300 Subject: [PATCH 60/67] [ML] Fix compilation after upstream changes --- .../ml/action/GetDataFrameAnalyticsActionResponseTests.java | 4 ++-- .../ml/action/PutDataFrameAnalyticsActionRequestTests.java | 4 ++-- .../ml/action/PutDataFrameAnalyticsActionResponseTests.java | 4 ++-- .../core/ml/dataframe/DataFrameAnalyticsConfigTests.java | 4 ++-- .../core/ml/dataframe/DataFrameAnalyticsSourceTests.java | 4 ++-- 5 files changed, 
10 insertions(+), 10 deletions(-) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java index 38a3396316602..8a7b6717abd92 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -26,7 +26,7 @@ public class GetDataFrameAnalyticsActionResponseTests extends AbstractStreamable protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); - namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -34,7 +34,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java index d00fa4384be8a..1e5416d5a5dce 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java @@ -33,7 +33,7 @@ public void setUpId() { protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); - namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -41,7 +41,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java index c9f678b13df2a..d323505828e42 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java @@ -24,7 +24,7 @@ public class PutDataFrameAnalyticsActionResponseTests extends AbstractStreamable protected NamedWriteableRegistry getNamedWriteableRegistry() { List 
namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); - namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new NamedWriteableRegistry(namedWriteables); } @@ -32,7 +32,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index a5df1f83c3d37..dd9b229913aa9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -56,7 +56,7 @@ protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws protected NamedWriteableRegistry getNamedWriteableRegistry() { List namedWriteables = new ArrayList<>(); namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); - namedWriteables.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); return new 
NamedWriteableRegistry(namedWriteables); } @@ -64,7 +64,7 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { protected NamedXContentRegistry xContentRegistry() { List namedXContent = new ArrayList<>(); namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new SearchModule(Settings.EMPTY, false, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java index 7783354d425a9..8c42dfb7a4cb7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSourceTests.java @@ -23,13 +23,13 @@ public class DataFrameAnalyticsSourceTests extends AbstractSerializingTestCase Date: Tue, 11 Jun 2019 16:45:18 +0300 Subject: [PATCH 61/67] [FEATURE][ML] Ensure data extractor is not leaking scroll contexts (#42960) --- .../client/MachineLearningIT.java | 13 +- .../extractor/DataFrameDataExtractor.java | 50 +-- .../DataFrameDataExtractorTests.java | 341 ++++++++++++++++++ 3 files changed, 373 insertions(+), 31 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index a5055b7e4f05d..d6550964f9732 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -164,7 +164,6 @@ import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; @@ -1365,7 +1364,8 @@ public void testStartDataFrameAnalyticsConfig() throws Exception { String sourceIndex = "start-test-source-index"; String destIndex = "start-test-dest-index"; createIndex(sourceIndex, defaultMappingForTest()); - highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); @@ -1391,12 +1391,6 @@ public void testStartDataFrameAnalyticsConfig() throws Exception { new StartDataFrameAnalyticsRequest(configId), machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); - assertThat( - getAnalyticsState(configId), - anyOf( - equalTo(DataFrameAnalyticsState.STARTED), - equalTo(DataFrameAnalyticsState.REINDEXING), - equalTo(DataFrameAnalyticsState.ANALYZING))); // Wait for the analytics to stop. 
assertBusy(() -> assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)), 30, TimeUnit.SECONDS); @@ -1409,7 +1403,8 @@ public void testStopDataFrameAnalyticsConfig() throws Exception { String sourceIndex = "stop-test-source-index"; String destIndex = "stop-test-dest-index"; createIndex(sourceIndex, mappingForClassification()); - highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index a45185ebe213f..7b8452f635f9e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -7,6 +7,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.search.ClearScrollAction; import org.elasticsearch.action.search.ClearScrollRequest; import org.elasticsearch.action.search.SearchAction; @@ -20,7 +21,6 @@ import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; -import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils; import 
org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsFields; @@ -34,6 +34,7 @@ import java.util.Objects; import java.util.Optional; import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; import java.util.stream.Collectors; /** @@ -91,9 +92,28 @@ public Optional> next() throws IOException { protected List initScroll() throws IOException { LOGGER.debug("[{}] Initializing scroll", context.jobId); - SearchResponse searchResponse = executeSearchRequest(buildSearchRequest()); - LOGGER.debug("[{}] Search response was obtained", context.jobId); - return processSearchResponse(searchResponse); + return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest())); + } + + private List tryRequestWithSearchResponse(Supplier request) throws IOException { + try { + // We've set allow_partial_search_results to false which means if something + // goes wrong the request will throw. + SearchResponse searchResponse = request.get(); + LOGGER.debug("[{}] Search response was obtained", context.jobId); + + // Request was successful so we can restore the flag to retry if a future failure occurs + searchHasShardFailure = false; + + return processSearchResponse(searchResponse); + } catch (Exception e) { + if (searchHasShardFailure) { + throw e; + } + LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e); + markScrollAsErrored(); + return initScroll(); + } } protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) { @@ -103,6 +123,8 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest private SearchRequestBuilder buildSearchRequest() { SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE) .setScroll(SCROLL_TIMEOUT) + // This ensures the search throws if there are failures and the scroll context gets cleared automatically 
+ .setAllowPartialSearchResults(false) .addSort(DataFrameAnalyticsFields.ID, SortOrder.ASC) .setIndices(context.indices) .setSize(context.scrollSize) @@ -117,14 +139,6 @@ private SearchRequestBuilder buildSearchRequest() { } private List processSearchResponse(SearchResponse searchResponse) throws IOException { - - if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) { - LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId); - markScrollAsErrored(); - return initScroll(); - } - - ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse); scrollId = searchResponse.getScrollId(); if (searchResponse.getHits().getHits().length == 0) { hasNext = false; @@ -143,7 +157,6 @@ private List processSearchResponse(SearchResponse searchResponse) throws IO rows.add(createRow(hit)); } return rows; - } private Row createRow(SearchHit hit) { @@ -163,15 +176,13 @@ private Row createRow(SearchHit hit) { private List continueScroll() throws IOException { LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId); - SearchResponse searchResponse = executeSearchScrollRequest(scrollId); - LOGGER.debug("[{}] Search response was obtained", context.jobId); - return processSearchResponse(searchResponse); + return tryRequestWithSearchResponse(() -> executeSearchScrollRequest(scrollId)); } private void markScrollAsErrored() { // This could be a transient error with the scroll Id. // Reinitialise the scroll and try again but only once. 
- resetScroll(); + scrollId = null; searchHasShardFailure = true; } @@ -183,11 +194,6 @@ protected SearchResponse executeSearchScrollRequest(String scrollId) { .get()); } - private void resetScroll() { - clearScroll(scrollId); - scrollId = null; - } - private void clearScroll(String scrollId) { if (scrollId != null) { ClearScrollRequest request = new ClearScrollRequest(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java new file mode 100644 index 0000000000000..f6547e1e6e583 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -0,0 +1,341 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.ml.dataframe.extractor; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.search.ClearScrollAction; +import org.elasticsearch.action.search.ClearScrollRequest; +import org.elasticsearch.action.search.ClearScrollResponse; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; +import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Queue; +import java.util.stream.Collectors; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.mockito.Matchers.same; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class DataFrameDataExtractorTests extends ESTestCase { + + private static final String JOB_ID = "foo"; + + private Client client; + private 
List indices; + private ExtractedFields extractedFields; + private QueryBuilder query; + private int scrollSize; + private Map headers; + private ArgumentCaptor capturedClearScrollRequests; + private ActionFuture clearScrollFuture; + + @Before + public void setUpTests() { + ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + + indices = Arrays.asList("index-1", "index-2"); + query = QueryBuilders.matchAllQuery(); + extractedFields = new ExtractedFields(Arrays.asList( + ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.DOC_VALUE))); + scrollSize = 1000; + headers = Collections.emptyMap(); + + clearScrollFuture = mock(ActionFuture.class); + capturedClearScrollRequests = ArgumentCaptor.forClass(ClearScrollRequest.class); + when(client.execute(same(ClearScrollAction.INSTANCE), capturedClearScrollRequests.capture())).thenReturn(clearScrollFuture); + } + + public void testTwoPageExtraction() throws IOException { + TestExtractor dataExtractor = createExtractor(true); + + // First batch + SearchResponse response1 = createSearchResponse(Arrays.asList(1_1, 1_2, 1_3), Arrays.asList(2_1, 2_2, 2_3)); + dataExtractor.setNextResponse(response1); + + // Second batch + SearchResponse response2 = createSearchResponse(Arrays.asList(3_1), Arrays.asList(4_1)); + dataExtractor.setNextResponse(response2); + + // Third batch is empty + SearchResponse lastAndEmptyResponse = createEmptySearchResponse(); + dataExtractor.setNextResponse(lastAndEmptyResponse); + + assertThat(dataExtractor.hasNext(), is(true)); + + // First batch + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(3)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); 
+ assertThat(rows.get().get(1).getValues(), equalTo(new String[] {"12", "22"})); + assertThat(rows.get().get(2).getValues(), equalTo(new String[] {"13", "23"})); + assertThat(dataExtractor.hasNext(), is(true)); + + // Second batch + rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"31", "41"})); + assertThat(dataExtractor.hasNext(), is(true)); + + // Third batch should return empty + rows = dataExtractor.next(); + assertThat(rows.isEmpty(), is(true)); + assertThat(dataExtractor.hasNext(), is(false)); + + // Now let's assert we're sending the expected search request + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("allowPartialSearchResults=false")); + assertThat(searchRequest, containsString("indices=[index-1,index-2]")); + assertThat(searchRequest, containsString("\"size\":1000")); + assertThat(searchRequest, containsString("\"query\":{\"match_all\":{\"boost\":1.0}}")); + assertThat(searchRequest, containsString("\"docvalue_fields\":[{\"field\":\"field_1\"},{\"field\":\"field_2\"}]")); + assertThat(searchRequest, containsString("\"_source\":{\"includes\":[],\"excludes\":[]}")); + assertThat(searchRequest, containsString("\"sort\":[{\"_id_copy\":{\"order\":\"asc\"}}]")); + + // Check continue scroll requests had correct ids + assertThat(dataExtractor.capturedContinueScrollIds.size(), equalTo(2)); + assertThat(dataExtractor.capturedContinueScrollIds.get(0), equalTo(response1.getScrollId())); + assertThat(dataExtractor.capturedContinueScrollIds.get(1), equalTo(response2.getScrollId())); + + // Check we cleared the scroll with the latest scroll id + List capturedClearScrollRequests = getCapturedClearScrollIds(); + assertThat(capturedClearScrollRequests.size(), 
equalTo(1)); + assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId())); + } + + public void testRecoveryFromErrorOnSearchAfterRetry() throws IOException { + TestExtractor dataExtractor = createExtractor(true); + + // First search will fail + dataExtractor.setNextResponse(createResponseWithShardFailures()); + + // Next one will succeed + SearchResponse response = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); + dataExtractor.setNextResponse(response); + + // Last one + SearchResponse lastAndEmptyResponse = createEmptySearchResponse(); + dataExtractor.setNextResponse(lastAndEmptyResponse); + + assertThat(dataExtractor.hasNext(), is(true)); + + // First batch expected as normally since we'll retry after the error + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); + assertThat(dataExtractor.hasNext(), is(true)); + + // Next batch should return empty + rows = dataExtractor.next(); + assertThat(rows.isEmpty(), is(true)); + assertThat(dataExtractor.hasNext(), is(false)); + + // Check we cleared the scroll with the latest scroll id + List capturedClearScrollRequests = getCapturedClearScrollIds(); + assertThat(capturedClearScrollRequests.size(), equalTo(1)); + assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId())); + } + + public void testErrorOnSearchTwiceLeadsToFailure() { + TestExtractor dataExtractor = createExtractor(true); + + // First search will fail + dataExtractor.setNextResponse(createResponseWithShardFailures()); + // Next one fails again + dataExtractor.setNextResponse(createResponseWithShardFailures()); + + assertThat(dataExtractor.hasNext(), is(true)); + + expectThrows(RuntimeException.class, () -> dataExtractor.next()); + } + + public void testRecoveryFromErrorOnContinueScrollAfterRetry() throws IOException { + 
TestExtractor dataExtractor = createExtractor(true); + + // Search will succeed + SearchResponse response1 = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); + dataExtractor.setNextResponse(response1); + + // But the first continue scroll fails + dataExtractor.setNextResponse(createResponseWithShardFailures()); + + // The next one succeeds and we shall recover + SearchResponse response2 = createSearchResponse(Arrays.asList(1_2), Arrays.asList(2_2)); + dataExtractor.setNextResponse(response2); + + // Last one + SearchResponse lastAndEmptyResponse = createEmptySearchResponse(); + dataExtractor.setNextResponse(lastAndEmptyResponse); + + assertThat(dataExtractor.hasNext(), is(true)); + + // First batch expected as normally since we'll retry after the error + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); + assertThat(dataExtractor.hasNext(), is(true)); + + // We get second batch as we retried after the error + rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"12", "22"})); + assertThat(dataExtractor.hasNext(), is(true)); + + // Next batch should return empty + rows = dataExtractor.next(); + assertThat(rows.isEmpty(), is(true)); + assertThat(dataExtractor.hasNext(), is(false)); + + // Notice we've done two searches and two continues here + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(2)); + assertThat(dataExtractor.capturedContinueScrollIds.size(), equalTo(2)); + + // Check we cleared the scroll with the latest scroll id + List capturedClearScrollRequests = getCapturedClearScrollIds(); + assertThat(capturedClearScrollRequests.size(), equalTo(1)); + assertThat(capturedClearScrollRequests.get(0), equalTo(lastAndEmptyResponse.getScrollId())); + } + + 
public void testErrorOnContinueScrollTwiceLeadsToFailure() throws IOException { + TestExtractor dataExtractor = createExtractor(true); + + // Search will succeed + SearchResponse response1 = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); + dataExtractor.setNextResponse(response1); + + // But the first continue scroll fails + dataExtractor.setNextResponse(createResponseWithShardFailures()); + // As well as the second + dataExtractor.setNextResponse(createResponseWithShardFailures()); + + assertThat(dataExtractor.hasNext(), is(true)); + + // First batch expected as normally since we'll retry after the error + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); + assertThat(dataExtractor.hasNext(), is(true)); + + // We get second batch as we retried after the error + expectThrows(RuntimeException.class, () -> dataExtractor.next()); + } + + private TestExtractor createExtractor(boolean includeSource) { + DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( + JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource); + return new TestExtractor(client, context); + } + + private SearchResponse createSearchResponse(List field1Values, List field2Values) { + assertThat(field1Values.size(), equalTo(field2Values.size())); + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.getScrollId()).thenReturn(randomAlphaOfLength(1000)); + List hits = new ArrayList<>(); + for (int i = 0; i < field1Values.size(); i++) { + SearchHit hit = new SearchHit(randomInt()); + Map fields = new HashMap<>(); + fields.put("field_1", new DocumentField("field_1", Collections.singletonList(field1Values.get(i)))); + fields.put("field_2", new DocumentField("field_2", Collections.singletonList(field2Values.get(i)))); + hit.fields(fields); + hits.add(hit); + } + 
SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[0]), new TotalHits(hits.size(), TotalHits.Relation.EQUAL_TO), 1); + when(searchResponse.getHits()).thenReturn(searchHits); + return searchResponse; + } + + private SearchResponse createEmptySearchResponse() { + return createSearchResponse(Collections.emptyList(), Collections.emptyList()); + } + + private SearchResponse createResponseWithShardFailures() { + SearchResponse searchResponse = mock(SearchResponse.class); + when(searchResponse.status()).thenReturn(RestStatus.OK); + when(searchResponse.getShardFailures()).thenReturn( + new ShardSearchFailure[] { new ShardSearchFailure(new RuntimeException("shard failed"))}); + when(searchResponse.getFailedShards()).thenReturn(1); + when(searchResponse.getScrollId()).thenReturn(randomAlphaOfLength(1000)); + return searchResponse; + } + + private List getCapturedClearScrollIds() { + return capturedClearScrollRequests.getAllValues().stream().map(r -> r.getScrollIds().get(0)).collect(Collectors.toList()); + } + + private static class TestExtractor extends DataFrameDataExtractor { + + private Queue responses = new LinkedList<>(); + private List capturedSearchRequests = new ArrayList<>(); + private List capturedContinueScrollIds = new ArrayList<>(); + + TestExtractor(Client client, DataFrameDataExtractorContext context) { + super(client, context); + } + + void setNextResponse(SearchResponse searchResponse) { + responses.add(searchResponse); + } + + @Override + protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) { + capturedSearchRequests.add(searchRequestBuilder); + SearchResponse searchResponse = responses.remove(); + if (searchResponse.getShardFailures() != null) { + throw new RuntimeException(searchResponse.getShardFailures()[0].getCause()); + } + return searchResponse; + } + + @Override + protected SearchResponse executeSearchScrollRequest(String scrollId) { + capturedContinueScrollIds.add(scrollId); + SearchResponse 
searchResponse = responses.remove(); + if (searchResponse.getShardFailures() != null) { + throw new RuntimeException(searchResponse.getShardFailures()[0].getCause()); + } + return searchResponse; + } + } +} From 3783f875f38707c83407fdfc171c9a994d314aac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Tue, 11 Jun 2019 17:38:33 +0200 Subject: [PATCH 62/67] Add missing argument to the instantiation of DataFrameTransformConfig (#43099) --- .../dataframe/integration/DataFrameTransformProgressIT.java | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java index 3bd8cbae28aeb..79773fc93c071 100644 --- a/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java +++ b/x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java @@ -170,6 +170,7 @@ public void testGetProgress() throws Exception { sourceConfig, destConfig, null, + null, pivotConfig, null); From c3c45c0837c1ab0181ffb69c2ff64e425bffa52b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Wed, 12 Jun 2019 13:24:20 +0200 Subject: [PATCH 63/67] [ML] Documentation for Data Frame Analytics high-level REST client (#42288) --- .../MlClientDocumentationIT.java | 555 ++++++++++++++++-- .../ml/delete-data-frame-analytics.asciidoc | 28 + .../ml/evaluate-data-frame.asciidoc | 45 ++ .../get-data-frame-analytics-stats.asciidoc | 34 ++ .../ml/get-data-frame-analytics.asciidoc | 34 ++ .../ml/put-data-frame-analytics.asciidoc | 115 ++++ .../ml/start-data-frame-analytics.asciidoc | 28 + .../ml/stop-data-frame-analytics.asciidoc | 28 + 
.../high-level/supported-apis.asciidoc | 14 + 9 files changed, 845 insertions(+), 36 deletions(-) create mode 100644 docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc create mode 100644 docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc create mode 100644 docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc create mode 100644 docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index fe7d04a4e0a8d..0203a3c855d14 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -39,6 +39,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -47,6 +48,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import 
org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -61,6 +64,10 @@ import org.elasticsearch.client.ml.GetCalendarsResponse; import org.elasticsearch.client.ml.GetCategoriesRequest; import org.elasticsearch.client.ml.GetCategoriesResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetModelSnapshotsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; @@ -92,6 +99,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -101,8 +110,11 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -118,6 +130,21 @@ import 
org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric.ConfusionMatrix; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; @@ -139,13 +166,18 @@ import org.elasticsearch.client.ml.job.results.OverallBucket; import org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentFactory; import 
org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.tasks.TaskId; +import org.hamcrest.CoreMatchers; import org.junit.After; import java.io.IOException; @@ -870,18 +902,7 @@ public void testPreviewDatafeed() throws Exception { client.machineLearning().putJob(new PutJobRequest(job), RequestOptions.DEFAULT); String datafeedId = job.getId() + "-feed"; String indexName = "preview_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId, job.getId()) .setIndices(indexName) .build(); @@ -938,18 +959,7 @@ public void testStartDatafeed() throws Exception { client.machineLearning().putJob(new PutJobRequest(job), RequestOptions.DEFAULT); String datafeedId = job.getId() + "-feed"; String indexName = "start_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = 
DatafeedConfig.builder(datafeedId, job.getId()) .setIndices(indexName) .build(); @@ -1067,18 +1077,7 @@ public void testGetDatafeedStats() throws Exception { client.machineLearning().putJob(new PutJobRequest(secondJob), RequestOptions.DEFAULT); String datafeedId1 = job.getId() + "-feed"; String indexName = "datafeed_stats_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId1, job.getId()) .setIndices(indexName) .build(); @@ -2802,6 +2801,455 @@ public void onFailure(Exception e) { } } + public void testGetDataFrameAnalytics() throws Exception { + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::get-data-frame-analytics-request + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::get-data-frame-analytics-request + + // tag::get-data-frame-analytics-execute + GetDataFrameAnalyticsResponse response = client.machineLearning().getDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::get-data-frame-analytics-execute + + // tag::get-data-frame-analytics-response + List configs = response.getAnalytics(); + // end::get-data-frame-analytics-response + + assertThat(configs.size(), equalTo(1)); + } + { + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::get-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + 
@Override + public void onResponse(GetDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-data-frame-analytics-execute-async + client.machineLearning().getDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testGetDataFrameAnalyticsStats() throws Exception { + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::get-data-frame-analytics-stats-request + GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config"); // <1> + // end::get-data-frame-analytics-stats-request + + // tag::get-data-frame-analytics-stats-execute + GetDataFrameAnalyticsStatsResponse response = + client.machineLearning().getDataFrameAnalyticsStats(request, RequestOptions.DEFAULT); + // end::get-data-frame-analytics-stats-execute + + // tag::get-data-frame-analytics-stats-response + List stats = response.getAnalyticsStats(); + // end::get-data-frame-analytics-stats-response + + assertThat(stats.size(), equalTo(1)); + } + { + GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config"); + + // tag::get-data-frame-analytics-stats-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(GetDataFrameAnalyticsStatsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-data-frame-analytics-stats-execute-listener + + 
// Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-data-frame-analytics-stats-execute-async + client.machineLearning().getDataFrameAnalyticsStatsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-data-frame-analytics-stats-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testPutDataFrameAnalytics() throws Exception { + RestHighLevelClient client = highLevelClient(); + { + // tag::put-data-frame-analytics-query-config + QueryConfig queryConfig = new QueryConfig(new MatchAllQueryBuilder()); + // end::put-data-frame-analytics-query-config + + // tag::put-data-frame-analytics-source-config + DataFrameAnalyticsSource sourceConfig = DataFrameAnalyticsSource.builder() // <1> + .setIndex("put-test-source-index") // <2> + .setQueryConfig(queryConfig) // <3> + .build(); + // end::put-data-frame-analytics-source-config + + // tag::put-data-frame-analytics-dest-config + DataFrameAnalyticsDest destConfig = DataFrameAnalyticsDest.builder() // <1> + .setIndex("put-test-dest-index") // <2> + .build(); + // end::put-data-frame-analytics-dest-config + + // tag::put-data-frame-analytics-analysis-default + DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1> + // end::put-data-frame-analytics-analysis-default + + // tag::put-data-frame-analytics-analysis-customized + DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1> + .setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2> + .setNNeighbors(5) // <3> + .build(); + // end::put-data-frame-analytics-analysis-customized + + // tag::put-data-frame-analytics-analyzed-fields + FetchSourceContext analyzedFields = + new FetchSourceContext( + true, + new String[] { "included_field_1", "included_field_2" }, + new String[] { "excluded_field" }); + // end::put-data-frame-analytics-analyzed-fields + + // 
tag::put-data-frame-analytics-config + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder("my-analytics-config") // <1> + .setSource(sourceConfig) // <2> + .setDest(destConfig) // <3> + .setAnalysis(outlierDetection) // <4> + .setAnalyzedFields(analyzedFields) // <5> + .setModelMemoryLimit(new ByteSizeValue(5, ByteSizeUnit.MB)) // <6> + .build(); + // end::put-data-frame-analytics-config + + // tag::put-data-frame-analytics-request + PutDataFrameAnalyticsRequest request = new PutDataFrameAnalyticsRequest(config); // <1> + // end::put-data-frame-analytics-request + + // tag::put-data-frame-analytics-execute + PutDataFrameAnalyticsResponse response = client.machineLearning().putDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::put-data-frame-analytics-execute + + // tag::put-data-frame-analytics-response + DataFrameAnalyticsConfig createdConfig = response.getConfig(); + // end::put-data-frame-analytics-response + + assertThat(createdConfig.getId(), equalTo("my-analytics-config")); + } + { + PutDataFrameAnalyticsRequest request = new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG); + // tag::put-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(PutDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::put-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + final CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::put-data-frame-analytics-execute-async + client.machineLearning().putDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::put-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testDeleteDataFrameAnalytics() throws Exception { + RestHighLevelClient client = 
highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::delete-data-frame-analytics-request + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::delete-data-frame-analytics-request + + // tag::delete-data-frame-analytics-execute + AcknowledgedResponse response = client.machineLearning().deleteDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::delete-data-frame-analytics-execute + + // tag::delete-data-frame-analytics-response + boolean acknowledged = response.isAcknowledged(); + // end::delete-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::delete-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(AcknowledgedResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::delete-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::delete-data-frame-analytics-execute-async + client.machineLearning().deleteDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::delete-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testStartDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + highLevelClient().index( + new 
IndexRequest(DF_ANALYTICS_CONFIG.getSource().getIndex()).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::start-data-frame-analytics-request + StartDataFrameAnalyticsRequest request = new StartDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::start-data-frame-analytics-request + + // tag::start-data-frame-analytics-execute + AcknowledgedResponse response = client.machineLearning().startDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::start-data-frame-analytics-execute + + // tag::start-data-frame-analytics-response + boolean acknowledged = response.isAcknowledged(); + // end::start-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + { + StartDataFrameAnalyticsRequest request = new StartDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::start-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(AcknowledgedResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::start-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::start-data-frame-analytics-execute-async + client.machineLearning().startDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::start-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), 
equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + } + + public void testStopDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + highLevelClient().index( + new IndexRequest(DF_ANALYTICS_CONFIG.getSource().getIndex()).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::stop-data-frame-analytics-request + StopDataFrameAnalyticsRequest request = new StopDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::stop-data-frame-analytics-request + + // tag::stop-data-frame-analytics-execute + StopDataFrameAnalyticsResponse response = client.machineLearning().stopDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::stop-data-frame-analytics-execute + + // tag::stop-data-frame-analytics-response + boolean acknowledged = response.isStopped(); + // end::stop-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + { + StopDataFrameAnalyticsRequest request = new StopDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::stop-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(StopDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::stop-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::stop-data-frame-analytics-execute-async + client.machineLearning().stopDataFrameAnalyticsAsync(request, 
RequestOptions.DEFAULT, listener); // <1> + // end::stop-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + } + + public void testEvaluateDataFrame() throws Exception { + String indexName = "evaluate-test-index"; + CreateIndexRequest createIndexRequest = + new CreateIndexRequest(indexName) + .mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("label") + .field("type", "keyword") + .endObject() + .startObject("p") + .field("type", "double") + .endObject() + .endObject() + .endObject()); + BulkRequest bulkRequest = + new BulkRequest(indexName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.1)) // #0 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.2)) // #1 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.3)) // #2 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.4)) // #3 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.7)) // #4 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.2)) // #5 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.3)) // #6 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.4)) // #7 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.8)) // #8 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.9)); // #9 + RestHighLevelClient client = highLevelClient(); + client.indices().create(createIndexRequest, RequestOptions.DEFAULT); + client.bulk(bulkRequest, RequestOptions.DEFAULT); + { + // tag::evaluate-data-frame-request + EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1> + indexName, // <2> + 
new BinarySoftClassification( // <3> + "label", // <4> + "p", // <5> + // Evaluation metrics // <6> + PrecisionMetric.at(0.4, 0.5, 0.6), // <7> + RecallMetric.at(0.5, 0.7), // <8> + ConfusionMatrixMetric.at(0.5), // <9> + AucRocMetric.withCurve())); // <10> + // end::evaluate-data-frame-request + + // tag::evaluate-data-frame-execute + EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT); + // end::evaluate-data-frame-execute + + // tag::evaluate-data-frame-response + List metrics = response.getMetrics(); // <1> + + PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2> + double precision = precisionResult.getScoreByThreshold("0.4"); // <3> + + ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4> + ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5> + // end::evaluate-data-frame-response + + assertThat( + metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()), + containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME)); + assertThat(precision, closeTo(0.6, 1e-9)); + assertThat(confusionMatrix.getTruePositives(), CoreMatchers.equalTo(2L)); // docs #8 and #9 + assertThat(confusionMatrix.getFalsePositives(), CoreMatchers.equalTo(1L)); // doc #4 + assertThat(confusionMatrix.getTrueNegatives(), CoreMatchers.equalTo(4L)); // docs #0, #1, #2 and #3 + assertThat(confusionMatrix.getFalseNegatives(), CoreMatchers.equalTo(3L)); // docs #5, #6 and #7 + } + { + EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( + indexName, + new BinarySoftClassification( + "label", + "p", + PrecisionMetric.at(0.4, 0.5, 0.6), + RecallMetric.at(0.5, 0.7), + ConfusionMatrixMetric.at(0.5), + AucRocMetric.withCurve())); + + // tag::evaluate-data-frame-execute-listener + ActionListener listener = new ActionListener<>() { 
+ @Override + public void onResponse(EvaluateDataFrameResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::evaluate-data-frame-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::evaluate-data-frame-execute-async + client.machineLearning().evaluateDataFrameAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::evaluate-data-frame-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } public void testCreateFilter() throws Exception { RestHighLevelClient client = highLevelClient(); @@ -3140,4 +3588,39 @@ private String createFilter(RestHighLevelClient client) throws IOException { assertThat(createdFilter.getId(), equalTo("my_safe_domains")); return createdFilter.getId(); } + + private void createIndex(String indexName) throws IOException { + CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); + createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("timestamp") + .field("type", "date") + .endObject() + .startObject("total") + .field("type", "long") + .endObject() + .endObject() + .endObject()); + highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + } + + private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOException { + GetDataFrameAnalyticsStatsResponse statsResponse = + highLevelClient().machineLearning().getDataFrameAnalyticsStats( + new GetDataFrameAnalyticsStatsRequest(configId), RequestOptions.DEFAULT); + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + return stats.getState(); + } + + private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG = + 
DataFrameAnalyticsConfig.builder("my-analytics-config") + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); } diff --git a/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..2e5ade37107cf --- /dev/null +++ b/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: delete-data-frame-analytics +:request: DeleteDataFrameAnalyticsRequest +:response: AcknowledgedResponse +-- +[id="{upid}-{api}"] +=== Delete Data Frame Analytics API + +The Delete Data Frame Analytics API is used to delete an existing {dataframe-analytics-config}. +The API accepts a +{request}+ object as a request and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Delete Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-analytics-config} deletion. 
diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc new file mode 100644 index 0000000000000..660603d2e38e7 --- /dev/null +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -0,0 +1,45 @@ +-- +:api: evaluate-data-frame +:request: EvaluateDataFrameRequest +:response: EvaluateDataFrameResponse +-- +[id="{upid}-{api}"] +=== Evaluate Data Frame API + +The Evaluate Data Frame API is used to evaluate an ML algorithm that ran on a {dataframe}. +The API accepts an +{request}+ object and returns an +{response}+. + +[id="{upid}-{api}-request"] +==== Evaluate Data Frame Request + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new evaluation request +<2> Reference to an existing index +<3> Kind of evaluation to perform +<4> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false +<5> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive +<6> The remaining parameters are the metrics to be calculated based on the two fields described above. +<7> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6 +<8> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7 +<9> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5 +<10> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AuC ROC] calculated and the curve points returned + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested evaluation metrics. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- +<1> Fetching all the calculated metrics results +<2> Fetching precision metric by name +<3> Fetching precision at a given (0.4) threshold +<4> Fetching confusion matrix metric by name +<5> Fetching confusion matrix at a given (0.5) threshold \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc b/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc new file mode 100644 index 0000000000000..e1047e9b3e002 --- /dev/null +++ b/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc @@ -0,0 +1,34 @@ +-- +:api: get-data-frame-analytics-stats +:request: GetDataFrameAnalyticsStatsRequest +:response: GetDataFrameAnalyticsStatsResponse +-- +[id="{upid}-{api}"] +=== Get Data Frame Analytics Stats API + +The Get Data Frame Analytics Stats API is used to read the operational statistics of one or more {dataframe-analytics-config}s. +The API accepts a +{request}+ object and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Get Data Frame Analytics Stats Request + +A +{request}+ requires either a {dataframe-analytics-config} id, a comma separated list of ids or +the special wildcard `_all` to get the statistics for all {dataframe-analytics-config}s + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new GET Stats request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested {dataframe-analytics-config} statistics. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..c6d368efbcae9 --- /dev/null +++ b/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc @@ -0,0 +1,34 @@ +-- +:api: get-data-frame-analytics +:request: GetDataFrameAnalyticsRequest +:response: GetDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Get Data Frame Analytics API + +The Get Data Frame Analytics API is used to get one or more {dataframe-analytics-config}s. +The API accepts a +{request}+ object and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Get Data Frame Analytics Request + +A +{request}+ requires either a {dataframe-analytics-config} id, a comma separated list of ids or +the special wildcard `_all` to get all {dataframe-analytics-config}s. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new GET request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested {dataframe-analytics-config}s. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..05fbd5bc3922a --- /dev/null +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -0,0 +1,115 @@ +-- +:api: put-data-frame-analytics +:request: PutDataFrameAnalyticsRequest +:response: PutDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Put Data Frame Analytics API + +The Put Data Frame Analytics API is used to create a new {dataframe-analytics-config}. +The API accepts a +{request}+ object as a request and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Put Data Frame Analytics Request + +A +{request}+ requires the following argument: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> The configuration of the {dataframe-job} to create + +[id="{upid}-{api}-config"] +==== Data Frame Analytics Configuration + +The `DataFrameAnalyticsConfig` object contains all the details about the {dataframe-job} +configuration and contains the following arguments: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-config] +-------------------------------------------------- +<1> The {dataframe-analytics-config} id +<2> The source index and query from which to gather data +<3> The destination index +<4> The analysis to be performed +<5> The fields to be included in / excluded from the analysis +<6> The memory limit for the model created as part of the analysis process + +[id="{upid}-{api}-query-config"] + 
+==== SourceConfig + +The index and the query from which to collect data. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-source-config] +-------------------------------------------------- +<1> Constructing a new DataFrameAnalyticsSource +<2> The source index +<3> The query from which to gather the data. If query is not set, a `match_all` query is used by default. + +===== QueryConfig + +The query with which to select data from the source. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-query-config] +-------------------------------------------------- + +==== DestinationConfig + +The index to which data should be written by the {dataframe-job}. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-dest-config] +-------------------------------------------------- +<1> Constructing a new DataFrameAnalyticsDest +<2> The destination index + +==== Analysis + +The analysis to be performed. +Currently, only one analysis is supported: +OutlierDetection+. 
+ ++OutlierDetection+ analysis can be created in one of two ways: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analysis-default] +-------------------------------------------------- +<1> Constructing a new OutlierDetection object with default strategy to determine outliers + +or +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analysis-customized] +-------------------------------------------------- +<1> Constructing a new OutlierDetection object +<2> The method used to perform the analysis +<3> Number of neighbors taken into account during analysis + +==== Analyzed fields + +FetchContext object containing fields to be included in / excluded from the analysis + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analyzed-fields] +-------------------------------------------------- + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the newly created {dataframe-analytics-config}. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..610607daba1f8 --- /dev/null +++ b/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: start-data-frame-analytics +:request: StartDataFrameAnalyticsRequest +:response: AcknowledgedResponse +-- +[id="{upid}-{api}"] +=== Start Data Frame Analytics API + +The Start Data Frame Analytics API is used to start an existing {dataframe-analytics-config}. +It accepts a +{request}+ object and responds with a +{response}+ object. + +[id="{upid}-{api}-request"] +==== Start Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new start request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-job} has started. 
\ No newline at end of file diff --git a/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..243c075e18b03 --- /dev/null +++ b/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: stop-data-frame-analytics +:request: StopDataFrameAnalyticsRequest +:response: StopDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Stop Data Frame Analytics API + +The Stop Data Frame Analytics API is used to stop a running {dataframe-analytics-config}. +It accepts a +{request}+ object and responds with a +{response}+ object. + +[id="{upid}-{api}-request"] +==== Stop Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new stop request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-job} has stopped. 
\ No newline at end of file diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc index 4e28efc2941db..21ebdfab65155 100644 --- a/docs/java-rest/high-level/supported-apis.asciidoc +++ b/docs/java-rest/high-level/supported-apis.asciidoc @@ -285,6 +285,13 @@ The Java High Level REST Client supports the following Machine Learning APIs: * <<{upid}-put-calendar-job>> * <<{upid}-delete-calendar-job>> * <<{upid}-delete-calendar>> +* <<{upid}-get-data-frame-analytics>> +* <<{upid}-get-data-frame-analytics-stats>> +* <<{upid}-put-data-frame-analytics>> +* <<{upid}-delete-data-frame-analytics>> +* <<{upid}-start-data-frame-analytics>> +* <<{upid}-stop-data-frame-analytics>> +* <<{upid}-evaluate-data-frame>> * <<{upid}-put-filter>> * <<{upid}-get-filters>> * <<{upid}-update-filter>> @@ -329,6 +336,13 @@ include::ml/delete-calendar-event.asciidoc[] include::ml/put-calendar-job.asciidoc[] include::ml/delete-calendar-job.asciidoc[] include::ml/delete-calendar.asciidoc[] +include::ml/get-data-frame-analytics.asciidoc[] +include::ml/get-data-frame-analytics-stats.asciidoc[] +include::ml/put-data-frame-analytics.asciidoc[] +include::ml/delete-data-frame-analytics.asciidoc[] +include::ml/start-data-frame-analytics.asciidoc[] +include::ml/stop-data-frame-analytics.asciidoc[] +include::ml/evaluate-data-frame.asciidoc[] include::ml/put-filter.asciidoc[] include::ml/get-filters.asciidoc[] include::ml/update-filter.asciidoc[] From cf8a8bdf0cebfdfe360a3417c3e96e857b4967d4 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 14 Jun 2019 15:18:26 +0300 Subject: [PATCH 64/67] [FEATURE][ML] Specify stability in df-analytics apis --- .../rest-api-spec/api/ml.delete_data_frame_analytics.json | 1 + .../test/resources/rest-api-spec/api/ml.evaluate_data_frame.json | 1 + .../resources/rest-api-spec/api/ml.get_data_frame_analytics.json | 1 + .../rest-api-spec/api/ml.get_data_frame_analytics_stats.json | 1 + 
.../resources/rest-api-spec/api/ml.put_data_frame_analytics.json | 1 + .../rest-api-spec/api/ml.start_data_frame_analytics.json | 1 + .../rest-api-spec/api/ml.stop_data_frame_analytics.json | 1 + 7 files changed, 7 insertions(+) diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json index a09259fabb8be..cf4d0ed4ec7f5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.delete_data_frame_analytics.json @@ -1,5 +1,6 @@ { "ml.delete_data_frame_analytics": { + "stability": "experimental", "methods": [ "DELETE" ], "url": { "path": "/_ml/data_frame/analytics/{id}", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json index 1a4859c796095..fb6e5ca5156cf 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.evaluate_data_frame.json @@ -1,5 +1,6 @@ { "ml.evaluate_data_frame": { + "stability": "experimental", "methods": [ "POST" ], "url": { "path": "/_ml/data_frame/_evaluate", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json index 491ced8b541da..dfb8de1310d04 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics.json @@ -1,5 +1,6 @@ { "ml.get_data_frame_analytics": { + "stability": "experimental", "methods": [ "GET"], "url": { "path": "/_ml/data_frame/analytics/{id}", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json 
b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json index 87ffe6c0fd722..4ae7b754403c5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.get_data_frame_analytics_stats.json @@ -1,5 +1,6 @@ { "ml.get_data_frame_analytics_stats": { + "stability": "experimental", "methods": [ "GET"], "url": { "path": "/_ml/data_frame/analytics/{id}/_stats", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json index 1f3183920aca5..5cee69e3ab951 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_data_frame_analytics.json @@ -1,5 +1,6 @@ { "ml.put_data_frame_analytics": { + "stability": "experimental", "methods": [ "PUT" ], "url": { "path": "/_ml/data_frame/analytics/{id}", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json index dfe0cac2f7b67..3436623c61b55 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.start_data_frame_analytics.json @@ -1,5 +1,6 @@ { "ml.start_data_frame_analytics": { + "stability": "experimental", "methods": [ "POST" ], "url": { "path": "/_ml/data_frame/analytics/{id}/_start", diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json index cc95def45fce6..962e4e391a045 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json +++ 
b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.stop_data_frame_analytics.json @@ -1,5 +1,6 @@ { "ml.stop_data_frame_analytics": { + "stability": "experimental", "methods": [ "POST" ], "url": { "path": "/_ml/data_frame/analytics/{id}/_stop", From eced35342cd3a26b524718eed1ca35c738a5ce9d Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 14 Jun 2019 17:45:39 +0300 Subject: [PATCH 65/67] [FEATURE][ML] Fetch from source when fields are more then docvalue limit (#43204) --- .../integration/RunDataFrameAnalyticsIT.java | 69 +++++++++++++++ .../extractor/fields/ExtractedField.java | 30 +++++++ .../extractor/DataFrameDataExtractor.java | 19 +++- .../DataFrameDataExtractorFactory.java | 74 ++++++++++++---- .../extractor/ExtractedFieldsDetector.java | 20 ++++- .../DataFrameDataExtractorTests.java | 65 ++++++++++++-- .../ExtractedFieldsDetectorTests.java | 86 ++++++++++++++++--- 7 files changed, 326 insertions(+), 37 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java index 1f3899939938e..7d29cbc345789 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java @@ -6,6 +6,8 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.action.admin.indices.exists.indices.IndicesExistsRequest; +import org.elasticsearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetResponse; @@ -13,6 
+15,8 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; @@ -147,6 +151,71 @@ public void testOutlierDetectionWithEnoughDocumentsToScroll() throws Exception { assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount)); } + public void testOutlierDetectionWithMoreFieldsThanDocValueFieldLimit() throws Exception { + String sourceIndex = "test-outlier-detection-with-more-fields-than-docvalue-limit"; + + client().admin().indices().prepareCreate(sourceIndex).get(); + + GetSettingsRequest getSettingsRequest = new GetSettingsRequest(); + getSettingsRequest.indices(sourceIndex); + getSettingsRequest.names(IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey()); + getSettingsRequest.includeDefaults(true); + + GetSettingsResponse docValueLimitSetting = client().admin().indices().getSettings(getSettingsRequest).actionGet(); + int docValueLimit = IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.get( + docValueLimitSetting.getIndexToSettings().values().iterator().next().value); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + for (int i = 0; i < 100; i++) { + + StringBuilder source = new StringBuilder("{"); + for (int fieldCount = 0; fieldCount < docValueLimit + 1; fieldCount++) { + source.append("\"field_").append(fieldCount).append("\":").append(randomDouble()); + if (fieldCount < docValueLimit) { + source.append(","); + } + } + source.append("}"); + + IndexRequest indexRequest = new IndexRequest(sourceIndex); + indexRequest.source(source.toString(), XContentType.JSON); + 
bulkRequestBuilder.add(indexRequest); + } + BulkResponse bulkResponse = bulkRequestBuilder.get(); + if (bulkResponse.hasFailures()) { + fail("Failed to index data: " + bulkResponse.buildFailureMessage()); + } + + String id = "test_outlier_detection_with_more_fields_than_docvalue_limit"; + DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(id, sourceIndex, null); + registerAnalytics(config); + putAnalytics(config); + + assertState(id, DataFrameAnalyticsState.STOPPED); + + startAnalytics(id); + waitUntilAnalyticsIsStopped(id); + + SearchResponse sourceData = client().prepareSearch(sourceIndex).get(); + for (SearchHit hit : sourceData.getHits()) { + GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get(); + assertThat(destDocGetResponse.isExists(), is(true)); + Map sourceDoc = hit.getSourceAsMap(); + Map destDoc = destDocGetResponse.getSource(); + for (String field : sourceDoc.keySet()) { + assertThat(destDoc.containsKey(field), is(true)); + assertThat(destDoc.get(field), equalTo(sourceDoc.get(field))); + } + assertThat(destDoc.containsKey("ml"), is(true)); + Map resultsObject = (Map) destDoc.get("ml"); + assertThat(resultsObject.containsKey("outlier_score"), is(true)); + double outlierScore = (double) resultsObject.get("outlier_score"); + assertThat(outlierScore, allOf(greaterThanOrEqualTo(0.0), lessThanOrEqualTo(100.0))); + } + } + public void testStopOutlierDetectionWithEnoughDocumentsToScroll() { String sourceIndex = "test-outlier-detection-with-enough-docs-to-scroll"; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java index 49642aaeb23f7..afd53ed258426 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/fields/ExtractedField.java @@ -58,6 +58,8 @@ public ExtractionMethod getExtractionMethod() { public abstract Object[] value(SearchHit hit); + public abstract boolean supportsFromSource(); + public String getDocValueFormat() { return null; } @@ -93,6 +95,14 @@ public static ExtractedField newField(String alias, String name, ExtractionMetho } } + public ExtractedField newFromSource() { + if (supportsFromSource()) { + return new FromSource(alias, name); + } + throw new IllegalStateException("Field (alias [" + alias + "], name [" + name + "]) should be extracted via [" + + extractionMethod + "] and cannot be extracted from source"); + } + private static class FromFields extends ExtractedField { FromFields(String alias, String name, ExtractionMethod extractionMethod) { @@ -108,6 +118,11 @@ public Object[] value(SearchHit hit) { } return new Object[0]; } + + @Override + public boolean supportsFromSource() { + return getExtractionMethod() == ExtractionMethod.DOC_VALUE; + } } private static class GeoShapeField extends FromSource { @@ -195,6 +210,11 @@ private String handleString(String geoString) { throw new IllegalArgumentException("Unexpected value for a geo_point field: " + geoString); } } + + @Override + public boolean supportsFromSource() { + return false; + } } private static class TimeField extends FromFields { @@ -223,6 +243,11 @@ public Object[] value(SearchHit hit) { public String getDocValueFormat() { return EPOCH_MILLIS_FORMAT; } + + @Override + public boolean supportsFromSource() { + return false; + } } private static class FromSource extends ExtractedField { @@ -257,6 +282,11 @@ public Object[] value(SearchHit hit) { return new Object[0]; } + @Override + public boolean supportsFromSource() { + return true; + } + @SuppressWarnings("unchecked") private static Map getNextLevel(Map source, String key) { Object nextLevel = source.get(key); diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java index 7b8452f635f9e..59cd78b4cc6fa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.fetch.StoredFieldsContext; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; @@ -128,8 +129,8 @@ private SearchRequestBuilder buildSearchRequest() { .addSort(DataFrameAnalyticsFields.ID, SortOrder.ASC) .setIndices(context.indices) .setSize(context.scrollSize) - .setQuery(context.query) - .setFetchSource(context.includeSource); + .setQuery(context.query); + setFetchSource(searchRequestBuilder); for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) { searchRequestBuilder.addDocValueField(docValueField.getName(), docValueField.getDocValueFormat()); @@ -138,6 +139,20 @@ private SearchRequestBuilder buildSearchRequest() { return searchRequestBuilder; } + private void setFetchSource(SearchRequestBuilder searchRequestBuilder) { + if (context.includeSource) { + searchRequestBuilder.setFetchSource(true); + } else { + String[] sourceFields = context.extractedFields.getSourceFields(); + if (sourceFields.length == 0) { + searchRequestBuilder.setFetchSource(false); + searchRequestBuilder.storedFields(StoredFieldsContext._NONE_); + } else { + searchRequestBuilder.setFetchSource(sourceFields, null); + } + } + } + private List processSearchResponse(SearchResponse searchResponse) 
throws IOException { scrollId = searchResponse.getScrollId(); if (searchResponse.getHits().getHits().length == 0) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java index f7fc0faf0b011..baf77c420c5cb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java @@ -5,21 +5,29 @@ */ package org.elasticsearch.xpack.ml.dataframe.extractor; +import com.carrotsearch.hppc.cursors.ObjectObjectCursor; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.indices.settings.get.GetSettingsRequest; +import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesRequest; import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.client.Client; +import org.elasticsearch.common.collect.ImmutableOpenMap; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; import java.util.Arrays; +import java.util.Iterator; import java.util.Map; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; public class DataFrameDataExtractorFactory { @@ -96,29 +104,65 @@ private static void 
validateIndexAndExtractFields(Client client, DataFrameAnalyticsConfig config, boolean isTaskRestarting, ActionListener listener) { - // Step 2. Extract fields (if possible) and notify listener + AtomicInteger docValueFieldsLimitHolder = new AtomicInteger(); + + // Step 3. Extract fields (if possible) and notify listener ActionListener fieldCapabilitiesHandler = ActionListener.wrap( - fieldCapabilitiesResponse -> listener.onResponse( - new ExtractedFieldsDetector(index, config, isTaskRestarting, fieldCapabilitiesResponse).detect()), + fieldCapabilitiesResponse -> listener.onResponse(new ExtractedFieldsDetector(index, config, isTaskRestarting, + docValueFieldsLimitHolder.get(), fieldCapabilitiesResponse).detect()), + listener::onFailure + ); + + // Step 2. Get field capabilities necessary to build the information of how to extract fields + ActionListener docValueFieldsLimitListener = ActionListener.wrap( + docValueFieldsLimit -> { + docValueFieldsLimitHolder.set(docValueFieldsLimit); + + FieldCapabilitiesRequest fieldCapabilitiesRequest = new FieldCapabilitiesRequest(); + fieldCapabilitiesRequest.indices(index); + fieldCapabilitiesRequest.fields("*"); + ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, () -> { + client.execute(FieldCapabilitiesAction.INSTANCE, fieldCapabilitiesRequest, fieldCapabilitiesHandler); + // This response gets discarded - the listener handles the real response + return null; + }); + }, + listener::onFailure + ); + + // Step 1. 
Get doc value fields limit + getDocValueFieldsLimit(client, index, docValueFieldsLimitListener); + } + + private static void getDocValueFieldsLimit(Client client, String index, ActionListener docValueFieldsLimitListener) { + ActionListener settingsListener = ActionListener.wrap(getSettingsResponse -> { + Integer minDocValueFieldsLimit = Integer.MAX_VALUE; + + ImmutableOpenMap indexToSettings = getSettingsResponse.getIndexToSettings(); + Iterator> iterator = indexToSettings.iterator(); + while (iterator.hasNext()) { + ObjectObjectCursor indexSettings = iterator.next(); + Integer indexMaxDocValueFields = IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.get(indexSettings.value); + if (indexMaxDocValueFields < minDocValueFieldsLimit) { + minDocValueFieldsLimit = indexMaxDocValueFields; + } + } + docValueFieldsLimitListener.onResponse(minDocValueFieldsLimit); + }, e -> { if (e instanceof IndexNotFoundException) { - listener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " + docValueFieldsLimitListener.onFailure(new ResourceNotFoundException("cannot retrieve data because index " + ((IndexNotFoundException) e).getIndex() + " does not exist")); } else { - listener.onFailure(e); + docValueFieldsLimitListener.onFailure(e); } } ); - // Step 1. 
Get field capabilities necessary to build the information of how to extract fields - FieldCapabilitiesRequest fieldCapabilitiesRequest = new FieldCapabilitiesRequest(); - fieldCapabilitiesRequest.indices(index); - fieldCapabilitiesRequest.fields("*"); - ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client, () -> { - client.execute(FieldCapabilitiesAction.INSTANCE, fieldCapabilitiesRequest, fieldCapabilitiesHandler); - // This response gets discarded - the listener handles the real response - return null; - }); + GetSettingsRequest getSettingsRequest = new GetSettingsRequest(); + getSettingsRequest.indices(index); + getSettingsRequest.includeDefaults(true); + getSettingsRequest.names(IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey()); + client.admin().indices().getSettings(getSettingsRequest, settingsListener); } - } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index 96f0181b1416c..b36fc6f182a06 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; import org.elasticsearch.common.Strings; import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.mapper.NumberFieldMapper; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -57,13 +58,15 @@ public class ExtractedFieldsDetector { private final String index; private final DataFrameAnalyticsConfig config; private final boolean isTaskRestarting; + private final int docValueFieldsLimit; 
private final FieldCapabilitiesResponse fieldCapabilitiesResponse; - ExtractedFieldsDetector(String index, DataFrameAnalyticsConfig config, boolean isTaskRestarting, + ExtractedFieldsDetector(String index, DataFrameAnalyticsConfig config, boolean isTaskRestarting, int docValueFieldsLimit, FieldCapabilitiesResponse fieldCapabilitiesResponse) { this.index = Objects.requireNonNull(index); this.config = Objects.requireNonNull(config); this.isTaskRestarting = isTaskRestarting; + this.docValueFieldsLimit = docValueFieldsLimit; this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse); } @@ -86,6 +89,14 @@ public ExtractedFields detect() { if (extractedFields.getAllFields().isEmpty()) { throw ExceptionsHelper.badRequestException("No compatible fields could be detected in index [{}]", index); } + if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { + extractedFields = fetchFromSourceIfSupported(extractedFields); + if (extractedFields.getDocValueFields().size() > docValueFieldsLimit) { + throw ExceptionsHelper.badRequestException("[{}] fields must be retrieved from doc_values but the limit is [{}]; " + + "please adjust the index level setting [{}]", extractedFields.getDocValueFields().size(), docValueFieldsLimit, + IndexSettings.MAX_DOCVALUE_FIELDS_SEARCH_SETTING.getKey()); + } + } return extractedFields; } @@ -141,4 +152,11 @@ private void includeAndExcludeFields(Set fields, String index) { } } + private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFields) { + List adjusted = new ArrayList<>(extractedFields.getAllFields().size()); + for (ExtractedField field : extractedFields.getDocValueFields()) { + adjusted.add(field.supportsFromSource() ? 
field.newFromSource() : field); + } + return new ExtractedFields(adjusted); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index f6547e1e6e583..6b0e88d759b81 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.Client; -import org.elasticsearch.common.document.DocumentField; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.query.QueryBuilder; @@ -26,6 +25,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField; import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedFields; +import org.elasticsearch.xpack.ml.test.SearchHitBuilder; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -33,7 +33,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -262,6 +261,58 @@ public void testErrorOnContinueScrollTwiceLeadsToFailure() throws IOException { expectThrows(RuntimeException.class, () -> dataExtractor.next()); } + public void testIncludeSourceIsFalseAndNoSourceFields() throws IOException { + TestExtractor dataExtractor = createExtractor(false); + + SearchResponse response = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); + dataExtractor.setNextResponse(response); + 
dataExtractor.setNextResponse(createEmptySearchResponse()); + + assertThat(dataExtractor.hasNext(), is(true)); + + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); + assertThat(dataExtractor.hasNext(), is(true)); + + assertThat(dataExtractor.next().isEmpty(), is(true)); + assertThat(dataExtractor.hasNext(), is(false)); + + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"docvalue_fields\":[{\"field\":\"field_1\"},{\"field\":\"field_2\"}]")); + assertThat(searchRequest, containsString("\"_source\":false")); + } + + public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOException { + extractedFields = new ExtractedFields(Arrays.asList( + ExtractedField.newField("field_1", ExtractedField.ExtractionMethod.DOC_VALUE), + ExtractedField.newField("field_2", ExtractedField.ExtractionMethod.SOURCE))); + + TestExtractor dataExtractor = createExtractor(false); + + SearchResponse response = createSearchResponse(Arrays.asList(1_1), Arrays.asList(2_1)); + dataExtractor.setNextResponse(response); + dataExtractor.setNextResponse(createEmptySearchResponse()); + + assertThat(dataExtractor.hasNext(), is(true)); + + Optional> rows = dataExtractor.next(); + assertThat(rows.isPresent(), is(true)); + assertThat(rows.get().size(), equalTo(1)); + assertThat(rows.get().get(0).getValues(), equalTo(new String[] {"11", "21"})); + assertThat(dataExtractor.hasNext(), is(true)); + + assertThat(dataExtractor.next().isEmpty(), is(true)); + assertThat(dataExtractor.hasNext(), is(false)); + + assertThat(dataExtractor.capturedSearchRequests.size(), equalTo(1)); + String searchRequest = 
dataExtractor.capturedSearchRequests.get(0).request().toString().replaceAll("\\s", ""); + assertThat(searchRequest, containsString("\"docvalue_fields\":[{\"field\":\"field_1\"}]")); + assertThat(searchRequest, containsString("\"_source\":{\"includes\":[\"field_2\"],\"excludes\":[]}")); + } + private TestExtractor createExtractor(boolean includeSource) { DataFrameDataExtractorContext context = new DataFrameDataExtractorContext( JOB_ID, extractedFields, indices, query, scrollSize, headers, includeSource); @@ -275,11 +326,11 @@ private SearchResponse createSearchResponse(List field1Values, List hits = new ArrayList<>(); for (int i = 0; i < field1Values.size(); i++) { SearchHit hit = new SearchHit(randomInt()); - Map fields = new HashMap<>(); - fields.put("field_1", new DocumentField("field_1", Collections.singletonList(field1Values.get(i)))); - fields.put("field_2", new DocumentField("field_2", Collections.singletonList(field2Values.get(i)))); - hit.fields(fields); - hits.add(hit); + SearchHitBuilder searchHitBuilder = new SearchHitBuilder(randomInt()) + .addField("field_1", Collections.singletonList(field1Values.get(i))) + .addField("field_2", Collections.singletonList(field2Values.get(i))) + .setSource("{\"field_1\":" + field1Values.get(i) + ",\"field_2\":" + field2Values.get(i) + "}"); + hits.add(searchHitBuilder.build()); } SearchHits searchHits = new SearchHits(hits.toArray(new SearchHit[0]), new TotalHits(hits.size(), TotalHits.Relation.EQUAL_TO), 1); when(searchResponse.getHits()).thenReturn(searchHits); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index 3aa6bfd6480d1..c035c44f117f4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -25,6 +25,7 @@ import java.util.Map; import java.util.stream.Collectors; +import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; import static org.mockito.Mockito.mock; @@ -41,12 +42,13 @@ public void testDetect_GivenFloatField() { .addAggregatableField("some_float", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); assertThat(allFields.get(0).getName(), equalTo("some_float")); + assertThat(allFields.get(0).getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)); } public void testDetect_GivenNumericFieldWithMultipleTypes() { @@ -55,12 +57,13 @@ public void testDetect_GivenNumericFieldWithMultipleTypes() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(1)); assertThat(allFields.get(0).getName(), equalTo("some_number")); + assertThat(allFields.get(0).getExtractionMethod(), equalTo(ExtractedField.ExtractionMethod.DOC_VALUE)); } public void testDetect_GivenNonNumericField() { @@ -68,7 +71,7 @@ public void testDetect_GivenNonNumericField() { .addAggregatableField("some_keyword", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - 
SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); @@ -79,7 +82,7 @@ public void testDetect_GivenFieldWithNumericAndNonNumericTypes() { .addAggregatableField("indecisive_field", "float", "keyword").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); @@ -93,13 +96,15 @@ public void testDetect_GivenMultipleFields() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List allFields = extractedFields.getAllFields(); assertThat(allFields.size(), equalTo(2)); assertThat(allFields.stream().map(ExtractedField::getName).collect(Collectors.toSet()), containsInAnyOrder("some_float", "some_long")); + assertThat(allFields.stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()), + contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); } public void testDetect_GivenIgnoredField() { @@ -107,7 +112,7 @@ public void testDetect_GivenIgnoredField() { .addAggregatableField("_id", "float").build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), 
false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); @@ -129,7 +134,7 @@ public void testDetect_ShouldSortFieldsAlphabetically() { FieldCapabilitiesResponse fieldCapabilities = mockFieldCapsResponseBuilder.build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -146,7 +151,7 @@ public void testDetectedExtractedFields_GivenIncludeWithMissingField() { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your_field1", "my*"}, new String[0]); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index] with name [your_field1]")); @@ -161,7 +166,7 @@ public void testDetectedExtractedFields_GivenExcludeAllValidFields() { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[0], new String[]{"my_*"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, 
fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("No compatible fields could be detected in index [source_index]")); } @@ -177,7 +182,7 @@ public void testDetectedExtractedFields_GivenInclusionsAndExclusions() { FetchSourceContext desiredFields = new FetchSourceContext(true, new String[]{"your*", "my_*"}, new String[]{"*nope"}); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(desiredFields), false, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -194,7 +199,7 @@ public void testDetectedExtractedFields_GivenIndexContainsResultsField() { .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), false, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), false, 100, fieldCapabilities); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> extractedFieldsDetector.detect()); assertThat(e.getMessage(), equalTo("Index [source_index] already has a field that matches the dest.results_field [ml]; " + @@ -210,7 +215,7 @@ public void testDetectedExtractedFields_GivenIndexContainsResultsFieldAndTaskIsR .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( - SOURCE_INDEX, buildAnalyticsConfig(), true, fieldCapabilities); + SOURCE_INDEX, buildAnalyticsConfig(), true, 100, fieldCapabilities); ExtractedFields extractedFields = extractedFieldsDetector.detect(); List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) @@ -218,6 +223,63 @@ public void 
testDetectedExtractedFields_GivenIndexContainsResultsFieldAndTaskIsR assertThat(extractedFieldNames, equalTo(Arrays.asList("my_field1", "your_field2"))); } + public void testDetectedExtractedFields_GivenLessFieldsThanDocValuesLimit() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_1", "float") + .addAggregatableField("field_2", "float") + .addAggregatableField("field_3", "float") + .addAggregatableField("a_keyword", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), true, 4, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, equalTo(Arrays.asList("field_1", "field_2", "field_3"))); + assertThat(extractedFields.getAllFields().stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()), + contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); + } + + public void testDetectedExtractedFields_GivenEqualFieldsToDocValuesLimit() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_1", "float") + .addAggregatableField("field_2", "float") + .addAggregatableField("field_3", "float") + .addAggregatableField("a_keyword", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), true, 3, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, equalTo(Arrays.asList("field_1", "field_2", "field_3"))); + 
assertThat(extractedFields.getAllFields().stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()), + contains(equalTo(ExtractedField.ExtractionMethod.DOC_VALUE))); + } + + public void testDetectedExtractedFields_GivenMoreFieldsThanDocValuesLimit() { + FieldCapabilitiesResponse fieldCapabilities = new MockFieldCapsResponseBuilder() + .addAggregatableField("field_1", "float") + .addAggregatableField("field_2", "float") + .addAggregatableField("field_3", "float") + .addAggregatableField("a_keyword", "keyword") + .build(); + + ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector( + SOURCE_INDEX, buildAnalyticsConfig(), true, 2, fieldCapabilities); + ExtractedFields extractedFields = extractedFieldsDetector.detect(); + + List extractedFieldNames = extractedFields.getAllFields().stream().map(ExtractedField::getName) + .collect(Collectors.toList()); + assertThat(extractedFieldNames, equalTo(Arrays.asList("field_1", "field_2", "field_3"))); + assertThat(extractedFields.getAllFields().stream().map(ExtractedField::getExtractionMethod).collect(Collectors.toSet()), + contains(equalTo(ExtractedField.ExtractionMethod.SOURCE))); + } + private static DataFrameAnalyticsConfig buildAnalyticsConfig() { return buildAnalyticsConfig(null); } From d86bea5140d094554cbffb5761e96a2bd4a88e13 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 24 Jun 2019 13:52:59 +0300 Subject: [PATCH 66/67] [FEATURE][ML] Stregthen source dest validations for DF analytics (#43399) This adds the following validations: - dest index name is valid - source index exists - dest index is not included in source index - dest index is matching a single index at most --- .../client/MachineLearningIT.java | 8 + .../MlClientDocumentationIT.java | 8 + .../ml/dataframe/DataFrameAnalyticsDest.java | 13 ++ .../DataFrameAnalyticsDestTests.java | 19 ++ .../ml/qa/ml-with-security/build.gradle | 7 + .../TransportPutDataFrameAnalyticsAction.java | 12 +- 
...ransportStartDataFrameAnalyticsAction.java | 6 +- .../ml/dataframe/SourceDestValidator.java | 65 +++++++ .../dataframe/SourceDestValidatorTests.java | 176 ++++++++++++++++++ .../test/ml/data_frame_analytics_crud.yml | 171 ++++++++++++++++- .../test/ml/start_data_frame_analytics.yml | 10 +- .../test/ml/stop_data_frame_analytics.yml | 4 + 12 files changed, 490 insertions(+), 9 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/SourceDestValidator.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/SourceDestValidatorTests.java diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index d6550964f9732..12a2fc6c93032 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1216,6 +1216,8 @@ public void testPutDataFrameAnalyticsConfig() throws Exception { .setAnalysis(OutlierDetection.createDefault()) .build(); + createIndex("put-test-source-index", defaultMappingForTest()); + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( new PutDataFrameAnalyticsRequest(config), machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); @@ -1243,6 +1245,8 @@ public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { .setAnalysis(OutlierDetection.createDefault()) .build(); + createIndex("get-test-source-index", defaultMappingForTest()); + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( new PutDataFrameAnalyticsRequest(config), machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); @@ -1256,6 +1260,8 @@ public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { } public 
void testGetDataFrameAnalyticsConfig_MultipleConfigs() throws Exception { + createIndex("get-test-source-index", defaultMappingForTest()); + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); String configIdPrefix = "get-test-config-"; int numberOfConfigs = 10; @@ -1461,6 +1467,8 @@ public void testDeleteDataFrameAnalyticsConfig() throws Exception { .setAnalysis(OutlierDetection.createDefault()) .build(); + createIndex("delete-test-source-index", defaultMappingForTest()); + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( new GetDataFrameAnalyticsRequest(configId + "*"), machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 0203a3c855d14..f33926582e3ea 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2802,6 +2802,8 @@ public void onFailure(Exception e) { } public void testGetDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + RestHighLevelClient client = highLevelClient(); client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); { @@ -2849,6 +2851,8 @@ public void onFailure(Exception e) { } public void testGetDataFrameAnalyticsStats() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + RestHighLevelClient client = highLevelClient(); client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); { @@ -2897,6 +2901,8 @@ public void onFailure(Exception e) { 
} public void testPutDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + RestHighLevelClient client = highLevelClient(); { // tag::put-data-frame-analytics-query-config @@ -2988,6 +2994,8 @@ public void onFailure(Exception e) { } public void testDeleteDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + RestHighLevelClient client = highLevelClient(); client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java index 98f1bdeb19189..3bc435336f062 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java @@ -13,11 +13,15 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.indices.InvalidIndexNameException; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; +import java.util.Locale; import java.util.Objects; +import static org.elasticsearch.cluster.metadata.MetaDataCreateIndexService.validateIndexOrAliasName; + public class DataFrameAnalyticsDest implements Writeable, ToXContentObject { public static final ParseField INDEX = new ParseField("index"); @@ -90,4 +94,13 @@ public String getIndex() { public String getResultsField() { return resultsField; } + + public void validate() { + if (index != null) { + validateIndexOrAliasName(index, InvalidIndexNameException::new); + if (index.toLowerCase(Locale.ROOT).equals(index) == false) { 
+ throw new InvalidIndexNameException(index, "dest.index must be lowercase"); + } + } + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java index 7332687723805..bf8ce4c8a99b0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -7,10 +7,14 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.indices.InvalidIndexNameException; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.AbstractSerializingTestCase; import java.io.IOException; +import static org.hamcrest.Matchers.equalTo; + public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase { @Override @@ -33,4 +37,19 @@ public static DataFrameAnalyticsDest createRandom() { protected Writeable.Reader instanceReader() { return DataFrameAnalyticsDest::new; } + + public void testValidate_GivenIndexWithFunkyChars() { + expectThrows(InvalidIndexNameException.class, () -> new DataFrameAnalyticsDest("