diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java index c11e577ef3639..e5a98b4632432 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java @@ -32,12 +32,14 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteFilterRequest; import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FlushJobRequest; import org.elasticsearch.client.ml.ForecastJobRequest; @@ -45,6 +47,8 @@ import org.elasticsearch.client.ml.GetCalendarEventsRequest; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -61,12 +65,15 @@ import org.elasticsearch.client.ml.PreviewDatafeedRequest; import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; +import 
org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.UpdateDatafeedRequest; import org.elasticsearch.client.ml.UpdateFilterRequest; @@ -581,6 +588,115 @@ static Request deleteCalendarEvent(DeleteCalendarEventRequest deleteCalendarEven return new Request(HttpDelete.METHOD_NAME, endpoint); } + static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) throws IOException { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(putRequest.getConfig().getId()) + .build(); + Request request = new Request(HttpPut.METHOD_NAME, endpoint); + request.setEntity(createEntity(putRequest, REQUEST_BODY_CONTENT_TYPE)); + return request; + } + + static Request getDataFrameAnalytics(GetDataFrameAnalyticsRequest getRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(Strings.collectionToCommaDelimitedString(getRequest.getIds())) + .build(); + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (getRequest.getPageParams() != null) { + PageParams pageParams = getRequest.getPageParams(); + if (pageParams.getFrom() != null) { + params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != null) { + params.putParam(PageParams.SIZE.getPreferredName(), 
pageParams.getSize().toString()); + } + } + if (getRequest.getAllowNoMatch() != null) { + params.putParam(GetDataFrameAnalyticsRequest.ALLOW_NO_MATCH.getPreferredName(), Boolean.toString(getRequest.getAllowNoMatch())); + } + request.addParameters(params.asMap()); + return request; + } + + static Request getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest getStatsRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(Strings.collectionToCommaDelimitedString(getStatsRequest.getIds())) + .addPathPartAsIs("_stats") + .build(); + Request request = new Request(HttpGet.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (getStatsRequest.getPageParams() != null) { + PageParams pageParams = getStatsRequest.getPageParams(); + if (pageParams.getFrom() != null) { + params.putParam(PageParams.FROM.getPreferredName(), pageParams.getFrom().toString()); + } + if (pageParams.getSize() != null) { + params.putParam(PageParams.SIZE.getPreferredName(), pageParams.getSize().toString()); + } + } + if (getStatsRequest.getAllowNoMatch() != null) { + params.putParam(GetDataFrameAnalyticsStatsRequest.ALLOW_NO_MATCH.getPreferredName(), + Boolean.toString(getStatsRequest.getAllowNoMatch())); + } + request.addParameters(params.asMap()); + return request; + } + + static Request startDataFrameAnalytics(StartDataFrameAnalyticsRequest startRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(startRequest.getId()) + .addPathPartAsIs("_start") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (startRequest.getTimeout() != null) { + params.withTimeout(startRequest.getTimeout()); + } + request.addParameters(params.asMap()); + return request; + } + + static Request stopDataFrameAnalytics(StopDataFrameAnalyticsRequest 
stopRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(stopRequest.getId()) + .addPathPartAsIs("_stop") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + RequestConverters.Params params = new RequestConverters.Params(); + if (stopRequest.getTimeout() != null) { + params.withTimeout(stopRequest.getTimeout()); + } + if (stopRequest.getAllowNoMatch() != null) { + params.putParam( + StopDataFrameAnalyticsRequest.ALLOW_NO_MATCH.getPreferredName(), Boolean.toString(stopRequest.getAllowNoMatch())); + } + request.addParameters(params.asMap()); + return request; + } + + static Request deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest deleteRequest) { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "analytics") + .addPathPart(deleteRequest.getId()) + .build(); + return new Request(HttpDelete.METHOD_NAME, endpoint); + } + + static Request evaluateDataFrame(EvaluateDataFrameRequest evaluateRequest) throws IOException { + String endpoint = new EndpointBuilder() + .addPathPartAsIs("_ml", "data_frame", "_evaluate") + .build(); + Request request = new Request(HttpPost.METHOD_NAME, endpoint); + request.setEntity(createEntity(evaluateRequest, REQUEST_BODY_CONTENT_TYPE)); + return request; + } + static Request putFilter(PutFilterRequest putFilterRequest) throws IOException { String endpoint = new EndpointBuilder() .addPathPartAsIs("_ml") diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java index 2e359931c1025..ea72c355a02e7 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java @@ -25,6 +25,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import 
org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -33,6 +34,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -47,6 +50,10 @@ import org.elasticsearch.client.ml.GetCalendarsResponse; import org.elasticsearch.client.ml.GetCategoriesRequest; import org.elasticsearch.client.ml.GetCategoriesResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -78,6 +85,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -87,8 +96,11 @@ import 
org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -1877,4 +1889,286 @@ public void setUpgradeModeAsync(SetUpgradeModeRequest request, RequestOptions op listener, Collections.emptySet()); } + + /** + * Creates a new Data Frame Analytics config + *

+ * For additional info + * see PUT Data Frame Analytics documentation + * + * @param request The {@link PutDataFrameAnalyticsRequest} containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return The {@link PutDataFrameAnalyticsResponse} containing the created + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public PutDataFrameAnalyticsResponse putDataFrameAnalytics(PutDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::putDataFrameAnalytics, + options, + PutDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Creates a new Data Frame Analytics config asynchronously and notifies listener upon completion + *

+ * For additional info + * see PUT Data Frame Analytics documentation + * + * @param request The {@link PutDataFrameAnalyticsRequest} containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void putDataFrameAnalyticsAsync(PutDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener<PutDataFrameAnalyticsResponse> listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::putDataFrameAnalytics, + options, + PutDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Gets a single or multiple Data Frame Analytics configs + *

+ * For additional info + * see GET Data Frame Analytics documentation + * + * @param request The {@link GetDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetDataFrameAnalyticsResponse} response object containing the + * {@link org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig} objects + */ + public GetDataFrameAnalyticsResponse getDataFrameAnalytics(GetDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getDataFrameAnalytics, + options, + GetDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets a single or multiple Data Frame Analytics configs asynchronously and notifies listener upon completion + *

+ * For additional info + * see GET Data Frame Analytics documentation + * + * @param request The {@link GetDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void getDataFrameAnalyticsAsync(GetDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener<GetDataFrameAnalyticsResponse> listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getDataFrameAnalytics, + options, + GetDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Gets the running statistics of a Data Frame Analytics + *

+ * For additional info + * see GET Data Frame Analytics Stats documentation + * + * @param request The {@link GetDataFrameAnalyticsStatsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link GetDataFrameAnalyticsStatsResponse} response object + */ + public GetDataFrameAnalyticsStatsResponse getDataFrameAnalyticsStats(GetDataFrameAnalyticsStatsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::getDataFrameAnalyticsStats, + options, + GetDataFrameAnalyticsStatsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Gets the running statistics of a Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see GET Data Frame Analytics Stats documentation + * + * @param request The {@link GetDataFrameAnalyticsStatsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void getDataFrameAnalyticsStatsAsync(GetDataFrameAnalyticsStatsRequest request, RequestOptions options, + ActionListener<GetDataFrameAnalyticsStatsResponse> listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::getDataFrameAnalyticsStats, + options, + GetDataFrameAnalyticsStatsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Starts Data Frame Analytics + *

+ * For additional info + * see Start Data Frame Analytics documentation + * + * @param request The {@link StartDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return action acknowledgement + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public AcknowledgedResponse startDataFrameAnalytics(StartDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::startDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Starts Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see Start Data Frame Analytics documentation + * + * @param request The {@link StartDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void startDataFrameAnalyticsAsync(StartDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener<AcknowledgedResponse> listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::startDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Stops Data Frame Analytics + *

+ * For additional info + * see Stop Data Frame Analytics documentation + * + * @param request The {@link StopDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link StopDataFrameAnalyticsResponse} + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public StopDataFrameAnalyticsResponse stopDataFrameAnalytics(StopDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::stopDataFrameAnalytics, + options, + StopDataFrameAnalyticsResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Stops Data Frame Analytics asynchronously and notifies listener upon completion + *

+ * For additional info + * see Stop Data Frame Analytics documentation + * + * @param request The {@link StopDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void stopDataFrameAnalyticsAsync(StopDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener<StopDataFrameAnalyticsResponse> listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::stopDataFrameAnalytics, + options, + StopDataFrameAnalyticsResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Deletes the given Data Frame Analytics config + *

+ * For additional info + * see DELETE Data Frame Analytics documentation + * + * @param request The {@link DeleteDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return action acknowledgement + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public AcknowledgedResponse deleteDataFrameAnalytics(DeleteDataFrameAnalyticsRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::deleteDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Deletes the given Data Frame Analytics config asynchronously and notifies listener upon completion + *

+ * For additional info + * see DELETE Data Frame Analytics documentation + * + * @param request The {@link DeleteDataFrameAnalyticsRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void deleteDataFrameAnalyticsAsync(DeleteDataFrameAnalyticsRequest request, RequestOptions options, + ActionListener<AcknowledgedResponse> listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::deleteDataFrameAnalytics, + options, + AcknowledgedResponse::fromXContent, + listener, + Collections.emptySet()); + } + + /** + * Evaluates the given Data Frame + *

+ * For additional info + * see Evaluate Data Frame documentation + * + * @param request The {@link EvaluateDataFrameRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @return {@link EvaluateDataFrameResponse} response object + * @throws IOException when there is a serialization issue sending the request or receiving the response + */ + public EvaluateDataFrameResponse evaluateDataFrame(EvaluateDataFrameRequest request, + RequestOptions options) throws IOException { + return restHighLevelClient.performRequestAndParseEntity(request, + MLRequestConverters::evaluateDataFrame, + options, + EvaluateDataFrameResponse::fromXContent, + Collections.emptySet()); + } + + /** + * Evaluates the given Data Frame asynchronously and notifies listener upon completion + *

+ * For additional info + * see Evaluate Data Frame documentation + * + * @param request The {@link EvaluateDataFrameRequest} + * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized + * @param listener Listener to be notified upon request completion + */ + public void evaluateDataFrameAsync(EvaluateDataFrameRequest request, RequestOptions options, + ActionListener<EvaluateDataFrameResponse> listener) { + restHighLevelClient.performRequestAsyncAndParseEntity(request, + MLRequestConverters::evaluateDataFrame, + options, + EvaluateDataFrameResponse::fromXContent, + listener, + Collections.emptySet()); + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java new file mode 100644 index 0000000000000..940b136c93daa --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/DataFrameNamedXContentProvider.java @@ -0,0 +1,41 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.dataframe; + +import org.elasticsearch.client.dataframe.transforms.SyncConfig; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.Arrays; +import java.util.List; + +public class DataFrameNamedXContentProvider implements NamedXContentProvider { + + @Override + public List<NamedXContentRegistry.Entry> getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(SyncConfig.class, + new ParseField(TimeSyncConfig.NAME), + TimeSyncConfig::fromXContent)); + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java index 34bcb595c206e..355e3ad9bbc0f 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfig.java @@ -30,6 +30,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import java.io.IOException; import java.time.Instant; @@ -44,6 +45,7 @@ public class DataFrameTransformConfig implements ToXContentObject { public static final ParseField SOURCE = new ParseField("source"); public static final ParseField DEST = new ParseField("dest"); public static final ParseField DESCRIPTION = new ParseField("description"); + public static final ParseField SYNC = new ParseField("sync"); public static final ParseField VERSION = new ParseField("version"); public static final ParseField
CREATE_TIME = new ParseField("create_time"); // types of transforms @@ -52,6 +54,7 @@ public class DataFrameTransformConfig implements ToXContentObject { private final String id; private final SourceConfig source; private final DestConfig dest; + private final SyncConfig syncConfig; private final PivotConfig pivotConfig; private final String description; private final Version transformVersion; @@ -63,17 +66,26 @@ public class DataFrameTransformConfig implements ToXContentObject { String id = (String) args[0]; SourceConfig source = (SourceConfig) args[1]; DestConfig dest = (DestConfig) args[2]; - PivotConfig pivotConfig = (PivotConfig) args[3]; - String description = (String)args[4]; - Instant createTime = (Instant)args[5]; - String transformVersion = (String)args[6]; - return new DataFrameTransformConfig(id, source, dest, pivotConfig, description, createTime, transformVersion); + SyncConfig syncConfig = (SyncConfig) args[3]; + PivotConfig pivotConfig = (PivotConfig) args[4]; + String description = (String)args[5]; + Instant createTime = (Instant)args[6]; + String transformVersion = (String)args[7]; + return new DataFrameTransformConfig(id, + source, + dest, + syncConfig, + pivotConfig, + description, + createTime, + transformVersion); }); static { PARSER.declareString(constructorArg(), ID); PARSER.declareObject(constructorArg(), (p, c) -> SourceConfig.PARSER.apply(p, null), SOURCE); PARSER.declareObject(constructorArg(), (p, c) -> DestConfig.PARSER.apply(p, null), DEST); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseSyncConfig(p), SYNC); PARSER.declareObject(optionalConstructorArg(), (p, c) -> PivotConfig.fromXContent(p), PIVOT_TRANSFORM); PARSER.declareString(optionalConstructorArg(), DESCRIPTION); PARSER.declareField(optionalConstructorArg(), @@ -81,6 +93,15 @@ public class DataFrameTransformConfig implements ToXContentObject { PARSER.declareString(optionalConstructorArg(), VERSION); } + private static SyncConfig parseSyncConfig(XContentParser 
parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + SyncConfig syncConfig = parser.namedObject(SyncConfig.class, parser.currentName(), true); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return syncConfig; + } + + public static DataFrameTransformConfig fromXContent(final XContentParser parser) { return PARSER.apply(parser, null); } @@ -97,12 +118,13 @@ public static DataFrameTransformConfig fromXContent(final XContentParser parser) * @return A DataFrameTransformConfig to preview, NOTE it will have a {@code null} id, destination and index. */ public static DataFrameTransformConfig forPreview(final SourceConfig source, final PivotConfig pivotConfig) { - return new DataFrameTransformConfig(null, source, null, pivotConfig, null, null, null); + return new DataFrameTransformConfig(null, source, null, null, pivotConfig, null, null, null); } DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final PivotConfig pivotConfig, final String description, final Instant createTime, @@ -110,6 +132,7 @@ public static DataFrameTransformConfig forPreview(final SourceConfig source, fin this.id = id; this.source = source; this.dest = dest; + this.syncConfig = syncConfig; this.pivotConfig = pivotConfig; this.description = description; this.createTime = createTime == null ? 
null : Instant.ofEpochMilli(createTime.toEpochMilli()); @@ -128,6 +151,10 @@ public DestConfig getDestination() { return dest; } + public SyncConfig getSyncConfig() { + return syncConfig; + } + public PivotConfig getPivotConfig() { return pivotConfig; } @@ -157,6 +184,11 @@ public XContentBuilder toXContent(final XContentBuilder builder, final Params pa if (dest != null) { builder.field(DEST.getPreferredName(), dest); } + if (syncConfig != null) { + builder.startObject(SYNC.getPreferredName()); + builder.field(syncConfig.getName(), syncConfig); + builder.endObject(); + } if (pivotConfig != null) { builder.field(PIVOT_TRANSFORM.getPreferredName(), pivotConfig); } @@ -189,6 +221,7 @@ public boolean equals(Object other) { && Objects.equals(this.source, that.source) && Objects.equals(this.dest, that.dest) && Objects.equals(this.description, that.description) + && Objects.equals(this.syncConfig, that.syncConfig) && Objects.equals(this.transformVersion, that.transformVersion) && Objects.equals(this.createTime, that.createTime) && Objects.equals(this.pivotConfig, that.pivotConfig); @@ -196,7 +229,7 @@ public boolean equals(Object other) { @Override public int hashCode() { - return Objects.hash(id, source, dest, pivotConfig, description, createTime, transformVersion); + return Objects.hash(id, source, dest, syncConfig, pivotConfig, description, createTime, transformVersion); } @Override @@ -213,6 +246,7 @@ public static class Builder { private String id; private SourceConfig source; private DestConfig dest; + private SyncConfig syncConfig; private PivotConfig pivotConfig; private String description; @@ -231,6 +265,11 @@ public Builder setDest(DestConfig dest) { return this; } + public Builder setSyncConfig(SyncConfig syncConfig) { + this.syncConfig = syncConfig; + return this; + } + public Builder setPivotConfig(PivotConfig pivotConfig) { this.pivotConfig = pivotConfig; return this; @@ -242,7 +281,7 @@ public Builder setDescription(String description) { } public DataFrameTransformConfig build() { -
return new DataFrameTransformConfig(id, source, dest, pivotConfig, description, null, null); + return new DataFrameTransformConfig(id, source, dest, syncConfig, pivotConfig, description, null, null); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java new file mode 100644 index 0000000000000..3ead35d0a491a --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/SyncConfig.java @@ -0,0 +1,30 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.xcontent.ToXContentObject; + +public interface SyncConfig extends ToXContentObject { + + /** + * Returns the name of the writeable object + */ + String getName(); +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java new file mode 100644 index 0000000000000..797ca3f896138 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfig.java @@ -0,0 +1,108 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class TimeSyncConfig implements SyncConfig { + + public static final String NAME = "time"; + + private static final ParseField FIELD = new ParseField("field"); + private static final ParseField DELAY = new ParseField("delay"); + + private final String field; + private final TimeValue delay; + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("time_sync_config", true, + args -> new TimeSyncConfig((String) args[0], args[1] != null ? 
(TimeValue) args[1] : TimeValue.ZERO)); + + static { + PARSER.declareString(constructorArg(), FIELD); + PARSER.declareField(optionalConstructorArg(), (p, c) -> TimeValue.parseTimeValue(p.textOrNull(), DELAY.getPreferredName()), DELAY, + ObjectParser.ValueType.STRING_OR_NULL); + } + + public static TimeSyncConfig fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public TimeSyncConfig(String field, TimeValue delay) { + this.field = field; + this.delay = delay; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD.getPreferredName(), field); + if (delay.duration() > 0) { + builder.field(DELAY.getPreferredName(), delay.getStringRep()); + } + builder.endObject(); + return builder; + } + + public String getField() { + return field; + } + + public TimeValue getDelay() { + return delay; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + final TimeSyncConfig that = (TimeSyncConfig) other; + + return Objects.equals(this.field, that.field) + && Objects.equals(this.delay, that.delay); + } + + @Override + public int hashCode() { + return Objects.hash(field, delay); + } + + @Override + public String getName() { + return NAME; + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..f03466632304d --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequest.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; + +import java.util.Objects; +import java.util.Optional; + +/** + * Request to delete a data frame analytics config + */ +public class DeleteDataFrameAnalyticsRequest implements Validatable { + + private final String id; + + public DeleteDataFrameAnalyticsRequest(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + @Override + public Optional validate() { + if (id == null) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DeleteDataFrameAnalyticsRequest other = (DeleteDataFrameAnalyticsRequest) o; + return Objects.equals(id, other.id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java new file mode 100644 index 0000000000000..2e3bbb170509c --- 
/dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameRequest.java @@ -0,0 +1,136 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +public class EvaluateDataFrameRequest implements ToXContentObject, Validatable { + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField 
EVALUATION = new ParseField("evaluation"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "evaluate_data_frame_request", true, args -> new EvaluateDataFrameRequest((List) args[0], (Evaluation) args[1])); + + static { + PARSER.declareStringArray(constructorArg(), INDEX); + PARSER.declareObject(constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION); + } + + private static Evaluation parseEvaluation(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + Evaluation evaluation = parser.namedObject(Evaluation.class, parser.currentName(), null); + ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return evaluation; + } + + public static EvaluateDataFrameRequest fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private List indices; + private Evaluation evaluation; + + public EvaluateDataFrameRequest(String index, Evaluation evaluation) { + this(Arrays.asList(index), evaluation); + } + + public EvaluateDataFrameRequest(List indices, Evaluation evaluation) { + setIndices(indices); + setEvaluation(evaluation); + } + + public List getIndices() { + return Collections.unmodifiableList(indices); + } + + public final void setIndices(List indices) { + Objects.requireNonNull(indices); + this.indices = new ArrayList<>(indices); + } + + public Evaluation getEvaluation() { + return evaluation; + } + + public final void setEvaluation(Evaluation evaluation) { + this.evaluation = evaluation; + } + + @Override + public Optional validate() { + List errors = new ArrayList<>(); + if (indices.isEmpty()) { + errors.add("At least one index must be specified"); + } + if (evaluation == null) { + errors.add("evaluation must not be 
null"); + } + return errors.isEmpty() + ? Optional.empty() + : Optional.of(ValidationException.withErrors(errors)); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .array(INDEX.getPreferredName(), indices.toArray()) + .startObject(EVALUATION.getPreferredName()) + .field(evaluation.getName(), evaluation) + .endObject() + .endObject(); + } + + @Override + public int hashCode() { + return Objects.hash(indices, evaluation); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + EvaluateDataFrameRequest that = (EvaluateDataFrameRequest) o; + return Objects.equals(indices, that.indices) + && Objects.equals(evaluation, that.evaluation); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java new file mode 100644 index 0000000000000..0709021ed4bd5 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EvaluateDataFrameResponse.java @@ -0,0 +1,117 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.NamedObjectNotFoundException; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; + +public class EvaluateDataFrameResponse implements ToXContentObject { + + public static EvaluateDataFrameResponse fromXContent(XContentParser parser) throws IOException { + if (parser.currentToken() == null) { + parser.nextToken(); + } + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + String evaluationName = parser.currentName(); + parser.nextToken(); + Map metrics = parser.map(LinkedHashMap::new, EvaluateDataFrameResponse::parseMetric); + List knownMetrics = + metrics.values().stream() + .filter(Objects::nonNull) // Filter out null values returned by {@link EvaluateDataFrameResponse::parseMetric}. 
+ .collect(Collectors.toList()); + ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return new EvaluateDataFrameResponse(evaluationName, knownMetrics); + } + + private static EvaluationMetric.Result parseMetric(XContentParser parser) throws IOException { + String metricName = parser.currentName(); + try { + return parser.namedObject(EvaluationMetric.Result.class, metricName, null); + } catch (NamedObjectNotFoundException e) { + parser.skipChildren(); + // Metric name not recognized. Return {@code null} value here and filter it out later. + return null; + } + } + + private final String evaluationName; + private final Map metrics; + + public EvaluateDataFrameResponse(String evaluationName, List metrics) { + this.evaluationName = Objects.requireNonNull(evaluationName); + this.metrics = Objects.requireNonNull(metrics).stream().collect(Collectors.toUnmodifiableMap(m -> m.getMetricName(), m -> m)); + } + + public String getEvaluationName() { + return evaluationName; + } + + public List getMetrics() { + return metrics.values().stream().collect(Collectors.toList()); + } + + @SuppressWarnings("unchecked") + public T getMetricByName(String metricName) { + Objects.requireNonNull(metricName); + return (T) metrics.get(metricName); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder + .startObject() + .field(evaluationName, metrics) + .endObject(); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + EvaluateDataFrameResponse that = (EvaluateDataFrameResponse) o; + return Objects.equals(evaluationName, that.evaluationName) + && Objects.equals(metrics, that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(evaluationName, metrics); + } + + @Override + public final String toString() { + return Strings.toString(this); + } 
+} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..40698c4b528fa --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequest.java @@ -0,0 +1,104 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +public class GetDataFrameAnalyticsRequest implements Validatable { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private final List ids; + private Boolean allowNoMatch; + private PageParams pageParams; + + /** + * Helper method to create a request that will get ALL Data Frame Analytics + * @return new {@link GetDataFrameAnalyticsRequest} object for the id "_all" + */ + public static GetDataFrameAnalyticsRequest getAllDataFrameAnalyticsRequest() { + return new GetDataFrameAnalyticsRequest("_all"); + } + + public GetDataFrameAnalyticsRequest(String... ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no data frame analytics. 
+ * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any data frame analytics + */ + public GetDataFrameAnalyticsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + public PageParams getPageParams() { + return pageParams; + } + + public GetDataFrameAnalyticsRequest setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + return this; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsRequest other = (GetDataFrameAnalyticsRequest) o; + return Objects.equals(ids, other.ids) + && Objects.equals(allowNoMatch, other.allowNoMatch) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, allowNoMatch, pageParams); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java new file mode 100644 index 0000000000000..76996e9d4d0b6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsResponse.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class GetDataFrameAnalyticsResponse { + + public static final ParseField DATA_FRAME_ANALYTICS = new ParseField("data_frame_analytics"); + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "get_data_frame_analytics", + true, + args -> new GetDataFrameAnalyticsResponse((List) args[0])); + + static { + PARSER.declareObjectArray(constructorArg(), (p, c) -> DataFrameAnalyticsConfig.fromXContent(p), DATA_FRAME_ANALYTICS); + } + + public static GetDataFrameAnalyticsResponse fromXContent(final XContentParser parser) { + return PARSER.apply(parser, null); + } + + private List analytics; + + public GetDataFrameAnalyticsResponse(List analytics) { + this.analytics = analytics; + } + + public List getAnalytics() { + return analytics; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsResponse other = (GetDataFrameAnalyticsResponse) o; + return Objects.equals(this.analytics, other.analytics); + } + + @Override + public int hashCode() { + 
return Objects.hash(analytics); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java new file mode 100644 index 0000000000000..f1e4a35fb661b --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequest.java @@ -0,0 +1,99 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.Validatable; +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.core.PageParams; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; + +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +/** + * Request to get data frame analytics stats + */ +public class GetDataFrameAnalyticsStatsRequest implements Validatable { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private final List ids; + private Boolean allowNoMatch; + private PageParams pageParams; + + public GetDataFrameAnalyticsStatsRequest(String... ids) { + this.ids = Arrays.asList(ids); + } + + public List getIds() { + return ids; + } + + public Boolean getAllowNoMatch() { + return allowNoMatch; + } + + /** + * Whether to ignore if a wildcard expression matches no data frame analytics. + * + * @param allowNoMatch If this is {@code false}, then an error is returned when a wildcard (or {@code _all}) + * does not match any data frame analytics + */ + public GetDataFrameAnalyticsStatsRequest setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + return this; + } + + public PageParams getPageParams() { + return pageParams; + } + + public GetDataFrameAnalyticsStatsRequest setPageParams(@Nullable PageParams pageParams) { + this.pageParams = pageParams; + return this; + } + + @Override + public Optional validate() { + if (ids == null || ids.isEmpty()) { + return Optional.of(ValidationException.withError("data frame analytics id must not be null")); + } + return Optional.empty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + GetDataFrameAnalyticsStatsRequest other = (GetDataFrameAnalyticsStatsRequest) o; + return Objects.equals(ids, other.ids) + && 
Objects.equals(allowNoMatch, other.allowNoMatch) + && Objects.equals(pageParams, other.pageParams); + } + + @Override + public int hashCode() { + return Objects.hash(ids, allowNoMatch, pageParams); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java new file mode 100644 index 0000000000000..5391a576e98b0 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsResponse.java @@ -0,0 +1,102 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.TaskOperationFailure;
+import org.elasticsearch.client.dataframe.AcknowledgedTasksResponse;
+import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Response holding data frame analytics statistics together with any task and node
+ * failures that were reported with them.
+ */
+public class GetDataFrameAnalyticsStatsResponse {
+
+    /**
+     * Parses a response using the class-level parser.
+     *
+     * @param parser the parser to read from
+     * @return the parsed response
+     */
+    public static GetDataFrameAnalyticsStatsResponse fromXContent(XContentParser parser) {
+        return GetDataFrameAnalyticsStatsResponse.PARSER.apply(parser, null);
+    }
+
+    private static final ParseField DATA_FRAME_ANALYTICS = new ParseField("data_frame_analytics");
+
+    // The casts are unchecked because the parser hands the constructor arguments over as Object[].
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<GetDataFrameAnalyticsStatsResponse, Void> PARSER =
+        new ConstructingObjectParser<>(
+            "get_data_frame_analytics_stats_response", true,
+            args -> new GetDataFrameAnalyticsStatsResponse(
+                (List<DataFrameAnalyticsStats>) args[0],
+                (List<TaskOperationFailure>) args[1],
+                (List<ElasticsearchException>) args[2]));
+
+    static {
+        PARSER.declareObjectArray(constructorArg(), (p, c) -> DataFrameAnalyticsStats.fromXContent(p), DATA_FRAME_ANALYTICS);
+        PARSER.declareObjectArray(
+            optionalConstructorArg(), (p, c) -> TaskOperationFailure.fromXContent(p), AcknowledgedTasksResponse.TASK_FAILURES);
+        PARSER.declareObjectArray(
+            optionalConstructorArg(), (p, c) -> ElasticsearchException.fromXContent(p), AcknowledgedTasksResponse.NODE_FAILURES);
+    }
+
+    private final List<DataFrameAnalyticsStats> analyticsStats;
+    private final List<TaskOperationFailure> taskFailures;
+    private final List<ElasticsearchException> nodeFailures;
+
+    /**
+     * @param analyticsStats the statistics entries; stored as given
+     * @param taskFailures   task failures, or {@code null}, which is stored as an empty list;
+     *                       a non-null list is wrapped as unmodifiable
+     * @param nodeFailures   node failures, or {@code null}, which is stored as an empty list;
+     *                       a non-null list is wrapped as unmodifiable
+     */
+    public GetDataFrameAnalyticsStatsResponse(List<DataFrameAnalyticsStats> analyticsStats,
+                                              @Nullable List<TaskOperationFailure> taskFailures,
+                                              @Nullable List<ElasticsearchException> nodeFailures) {
+        this.analyticsStats = analyticsStats;
+        this.taskFailures = taskFailures == null ? Collections.emptyList() : Collections.unmodifiableList(taskFailures);
+        this.nodeFailures = nodeFailures == null ? Collections.emptyList() : Collections.unmodifiableList(nodeFailures);
+    }
+
+    public List<DataFrameAnalyticsStats> getAnalyticsStats() {
+        return analyticsStats;
+    }
+
+    /** @return the node failures; never {@code null} */
+    public List<ElasticsearchException> getNodeFailures() {
+        return nodeFailures;
+    }
+
+    /** @return the task failures; never {@code null} */
+    public List<TaskOperationFailure> getTaskFailures() {
+        return taskFailures;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        GetDataFrameAnalyticsStatsResponse other = (GetDataFrameAnalyticsStatsResponse) o;
+        return Objects.equals(analyticsStats, other.analyticsStats)
+            && Objects.equals(nodeFailures, other.nodeFailures)
+            && Objects.equals(taskFailures, other.taskFailures);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(analyticsStats, nodeFailures, taskFailures);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java
index 892df340abd6b..a0f0d25f2ca01 100644
--- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/NodeAttributes.java
@@ -19,6 +19,7 @@
 package org.elasticsearch.client.ml;
 
 import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -147,4 +148,9 @@ public boolean equals(Object other) {
             Objects.equals(transportAddress, that.transportAddress) &&
             Objects.equals(attributes, that.attributes);
     }
+
+    @Override
+    public String toString() {
+        return
Strings.toString(this); + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java new file mode 100644 index 0000000000000..14950a74c9187 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequest.java @@ -0,0 +1,70 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Request wrapping a {@link DataFrameAnalyticsConfig}. The request serializes
+ * as the wrapped config itself.
+ */
+public class PutDataFrameAnalyticsRequest implements ToXContentObject, Validatable {
+
+    private final DataFrameAnalyticsConfig config;
+
+    /**
+     * @param config the config to send; a {@code null} config is reported by {@link #validate()}
+     */
+    public PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfig config) {
+        this.config = config;
+    }
+
+    public DataFrameAnalyticsConfig getConfig() {
+        return config;
+    }
+
+    /**
+     * Checks that a config was supplied.
+     *
+     * @return a validation error when the config is {@code null}, otherwise an empty optional
+     */
+    @Override
+    public Optional<ValidationException> validate() {
+        if (config == null) {
+            return Optional.of(ValidationException.withError("put requires a non-null data frame analytics config"));
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Delegates to the wrapped config, so the output is the config's own representation.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        return config.toXContent(builder, params);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        PutDataFrameAnalyticsRequest other = (PutDataFrameAnalyticsRequest) o;
+        return Objects.equals(config, other.config);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(config);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java
new file mode 100644
index 0000000000000..e6c4be15987d4
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsResponse.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Response wrapping the {@link DataFrameAnalyticsConfig} parsed from the reply.
+ */
+public class PutDataFrameAnalyticsResponse {
+
+    /**
+     * Reads a config from the parser and wraps it in a response.
+     */
+    public static PutDataFrameAnalyticsResponse fromXContent(XContentParser parser) throws IOException {
+        DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.fromXContent(parser);
+        return new PutDataFrameAnalyticsResponse(parsedConfig);
+    }
+
+    private final DataFrameAnalyticsConfig config;
+
+    public PutDataFrameAnalyticsResponse(DataFrameAnalyticsConfig config) {
+        this.config = config;
+    }
+
+    public DataFrameAnalyticsConfig getConfig() {
+        return config;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) {
+            return true;
+        }
+        if (o == null || o.getClass() != getClass()) {
+            return false;
+        }
+        PutDataFrameAnalyticsResponse that = (PutDataFrameAnalyticsResponse) o;
+        return Objects.equals(this.config, that.config);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(config);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java
new file mode
100644
index 0000000000000..68a925d15019a
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.unit.TimeValue;
+
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Request to start a data frame analytics job, identified by its id, with an optional timeout.
+ */
+public class StartDataFrameAnalyticsRequest implements Validatable {
+
+    private final String id;
+    private TimeValue timeout;
+
+    /**
+     * @param id id of the job to start; a {@code null} id is reported by {@link #validate()}
+     */
+    public StartDataFrameAnalyticsRequest(String id) {
+        this.id = id;
+    }
+
+    public String getId() {
+        return id;
+    }
+
+    public TimeValue getTimeout() {
+        return timeout;
+    }
+
+    /**
+     * Sets the timeout carried by this request.
+     *
+     * @param timeout the timeout; may be {@code null}
+     * @return this request, so calls can be chained
+     */
+    public StartDataFrameAnalyticsRequest setTimeout(@Nullable TimeValue timeout) {
+        this.timeout = timeout;
+        return this;
+    }
+
+    /**
+     * Checks that an id was supplied.
+     *
+     * @return a validation error when the id is {@code null}, otherwise an empty optional
+     */
+    @Override
+    public Optional<ValidationException> validate() {
+        if (id == null) {
+            return Optional.of(ValidationException.withError("data frame analytics id must not be null"));
+        }
+        return Optional.empty();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        StartDataFrameAnalyticsRequest other = (StartDataFrameAnalyticsRequest) o;
+        return Objects.equals(id, other.id)
+            && Objects.equals(timeout, other.timeout);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(id, timeout);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java
new file mode 100644
index 0000000000000..9608d40fc7d16
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.unit.TimeValue;
+
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Request to stop a data frame analytics job, identified by its id, with an optional
+ * timeout and an optional {@code allow_no_match} flag.
+ */
+public class StopDataFrameAnalyticsRequest implements Validatable {
+
+    public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
+
+    private final String id;
+    private TimeValue timeout;
+    private Boolean allowNoMatch;
+
+    /**
+     * @param id id of the job to stop; a {@code null} id is reported by {@link #validate()}
+     */
+    public StopDataFrameAnalyticsRequest(String id) {
+        this.id = id;
+    }
+
+    public String getId() {
+        return id;
+    }
+
+    public TimeValue getTimeout() {
+        return timeout;
+    }
+
+    /**
+     * Sets the timeout carried by this request.
+     *
+     * @param timeout the timeout; may be {@code null}
+     * @return this request, so calls can be chained
+     */
+    public StopDataFrameAnalyticsRequest setTimeout(@Nullable TimeValue timeout) {
+        this.timeout = timeout;
+        return this;
+    }
+
+    /**
+     * @return the flag, or {@code null} when {@link #setAllowNoMatch(boolean)} was never called
+     */
+    public Boolean getAllowNoMatch() {
+        return allowNoMatch;
+    }
+
+    /**
+     * Sets the {@code allow_no_match} flag.
+     *
+     * @return this request, so calls can be chained
+     */
+    public StopDataFrameAnalyticsRequest setAllowNoMatch(boolean allowNoMatch) {
+        this.allowNoMatch = allowNoMatch;
+        return this;
+    }
+
+    /**
+     * Checks that an id was supplied.
+     *
+     * @return a validation error when the id is {@code null}, otherwise an empty optional
+     */
+    @Override
+    public Optional<ValidationException> validate() {
+        if (id == null) {
+            return Optional.of(ValidationException.withError("data frame analytics id must not be null"));
+        }
+        return Optional.empty();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        StopDataFrameAnalyticsRequest other = (StopDataFrameAnalyticsRequest) o;
+        return Objects.equals(id, other.id)
+            && Objects.equals(timeout, other.timeout)
+            && Objects.equals(allowNoMatch, other.allowNoMatch);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(id, timeout, allowNoMatch);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java
new file mode 100644
index
0000000000000..5f45c6f9ea51f
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponse.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Response indicating if the Machine Learning Data Frame Analytics is now stopped or not
+ */
+public class StopDataFrameAnalyticsResponse implements ToXContentObject {
+
+    private static final ParseField STOPPED = new ParseField("stopped");
+
+    public static final ConstructingObjectParser<StopDataFrameAnalyticsResponse, Void> PARSER =
+        new ConstructingObjectParser<>(
+            "stop_data_frame_analytics_response",
+            true,
+            args -> new StopDataFrameAnalyticsResponse((Boolean) args[0]));
+
+    static {
+        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), STOPPED);
+    }
+
+    /**
+     * Parses a response using {@link #PARSER}.
+     */
+    public static StopDataFrameAnalyticsResponse fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    private final boolean stopped;
+
+    /**
+     * @param stopped the stopped status to report
+     */
+    public StopDataFrameAnalyticsResponse(boolean stopped) {
+        this.stopped = stopped;
+    }
+
+    /**
+     * Has the Data Frame Analytics stopped or not
+     *
+     * @return boolean value indicating the Data Frame Analytics stopped status
+     */
+    public boolean isStopped() {
+        return stopped;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        StopDataFrameAnalyticsResponse other = (StopDataFrameAnalyticsResponse) o;
+        return stopped == other.stopped;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(stopped);
+    }
+
+    /**
+     * Writes a single object holding the {@code stopped} field.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        return builder
+            .startObject()
+            .field(STOPPED.getPreferredName(), stopped)
+            .endObject();
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java
new file mode 100644
index 0000000000000..81b19eefce573
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalysis.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.xcontent.ToXContentObject;
+
+// An analysis that can be written as an XContent object and that exposes a name.
+// NOTE(review): the name is presumably what DataFrameAnalyticsConfig uses as the field key
+// inside its "analysis" object - confirm against implementations.
+public interface DataFrameAnalysis extends ToXContentObject {
+
+    // Returns the name of this analysis.
+    String getName();
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java
new file mode 100644
index 0000000000000..b1309e66afcd4
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfig.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.common.xcontent.XContentParserUtils;
+import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING;
+import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.VALUE;
+
+/**
+ * Immutable configuration of a data frame analytics job: id, source, destination and
+ * analysis are required; analyzed fields and the model memory limit are optional.
+ * Instances are created through {@link Builder}.
+ */
+public class DataFrameAnalyticsConfig implements ToXContentObject {
+
+    /**
+     * Parses a config: the parser fills a {@link Builder}, which is then built.
+     */
+    public static DataFrameAnalyticsConfig fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null).build();
+    }
+
+    /**
+     * @param id the config id; must not be {@code null}
+     * @return a new builder with the id already set
+     */
+    public static Builder builder(String id) {
+        return new Builder().setId(id);
+    }
+
+    private static final ParseField ID = new ParseField("id");
+    private static final ParseField SOURCE = new ParseField("source");
+    private static final ParseField DEST = new ParseField("dest");
+    private static final ParseField ANALYSIS = new ParseField("analysis");
+    private static final ParseField ANALYZED_FIELDS = new ParseField("analyzed_fields");
+    private static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit");
+
+    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>("data_frame_analytics_config", true, Builder::new);
+
+    static {
+        PARSER.declareString(Builder::setId, ID);
+        PARSER.declareObject(Builder::setSource, (p, c) -> DataFrameAnalyticsSource.fromXContent(p), SOURCE);
+        PARSER.declareObject(Builder::setDest, (p, c) -> DataFrameAnalyticsDest.fromXContent(p), DEST);
+        PARSER.declareObject(Builder::setAnalysis, (p, c) -> parseAnalysis(p), ANALYSIS);
+        PARSER.declareField(Builder::setAnalyzedFields,
+            (p, c) -> FetchSourceContext.fromXContent(p),
+            ANALYZED_FIELDS,
+            OBJECT_ARRAY_BOOLEAN_OR_STRING);
+        PARSER.declareField(Builder::setModelMemoryLimit,
+            (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), MODEL_MEMORY_LIMIT.getPreferredName()), MODEL_MEMORY_LIMIT, VALUE);
+    }
+
+    /**
+     * Reads the {@code analysis} object. It must hold exactly one field: the field name selects
+     * the {@link DataFrameAnalysis} implementation, and the object must end right after it.
+     */
+    private static DataFrameAnalysis parseAnalysis(XContentParser parser) throws IOException {
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation);
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation);
+        DataFrameAnalysis analysis = parser.namedObject(DataFrameAnalysis.class, parser.currentName(), true);
+        XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation);
+        return analysis;
+    }
+
+    private final String id;
+    private final DataFrameAnalyticsSource source;
+    private final DataFrameAnalyticsDest dest;
+    private final DataFrameAnalysis analysis;
+    private final FetchSourceContext analyzedFields;
+    private final ByteSizeValue modelMemoryLimit;
+
+    // id, source, dest and analysis are required: a null value throws NullPointerException.
+    private DataFrameAnalyticsConfig(String id, DataFrameAnalyticsSource source, DataFrameAnalyticsDest dest, DataFrameAnalysis analysis,
+                                     @Nullable FetchSourceContext analyzedFields, @Nullable ByteSizeValue modelMemoryLimit) {
+        this.id = Objects.requireNonNull(id);
+        this.source = Objects.requireNonNull(source);
+        this.dest = Objects.requireNonNull(dest);
+        this.analysis = Objects.requireNonNull(analysis);
+        this.analyzedFields = analyzedFields;
+        this.modelMemoryLimit = modelMemoryLimit;
+    }
+
+    public String getId() {
+        return id;
+    }
+
+    public DataFrameAnalyticsSource getSource() {
+        return source;
+    }
+
+    public DataFrameAnalyticsDest getDest() {
+        return dest;
+    }
+
+    public DataFrameAnalysis getAnalysis() {
+        return analysis;
+    }
+
+    /** @return the analyzed fields, or {@code null} when none were set */
+    public FetchSourceContext getAnalyzedFields() {
+        return analyzedFields;
+    }
+
+    /** @return the model memory limit, or {@code null} when none was set */
+    public ByteSizeValue getModelMemoryLimit() {
+        return modelMemoryLimit;
+    }
+
+    /**
+     * Writes the config as one object. The analysis is nested under its own name inside
+     * {@code analysis}; the optional fields are written only when they are set.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ID.getPreferredName(), id);
+        builder.field(SOURCE.getPreferredName(), source);
+        builder.field(DEST.getPreferredName(), dest);
+        builder.startObject(ANALYSIS.getPreferredName());
+        builder.field(analysis.getName(), analysis);
+        builder.endObject();
+        if (analyzedFields != null) {
+            builder.field(ANALYZED_FIELDS.getPreferredName(), analyzedFields);
+        }
+        if (modelMemoryLimit != null) {
+            builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), modelMemoryLimit.getStringRep());
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        DataFrameAnalyticsConfig other = (DataFrameAnalyticsConfig) o;
+        return Objects.equals(id, other.id)
+            && Objects.equals(source, other.source)
+            && Objects.equals(dest, other.dest)
+            && Objects.equals(analysis, other.analysis)
+            && Objects.equals(analyzedFields, other.analyzedFields)
+            && Objects.equals(modelMemoryLimit, other.modelMemoryLimit);
+    }
+
+    @Override
+    public int hashCode() {
+        // Reads the same fields as equals().
+        return Objects.hash(id, source, dest, analysis, analyzedFields, modelMemoryLimit);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    /**
+     * Builder for {@link DataFrameAnalyticsConfig}. The setters for id, source, dest and
+     * analysis reject {@code null}; the other two setters accept it.
+     */
+    public static class Builder {
+
+        private String id;
+        private DataFrameAnalyticsSource source;
+        private DataFrameAnalyticsDest dest;
+        private DataFrameAnalysis analysis;
+        private FetchSourceContext analyzedFields;
+        private ByteSizeValue modelMemoryLimit;
+
+        private Builder() {}
+
+        public Builder setId(String id) {
+            this.id = Objects.requireNonNull(id);
+            return this;
+        }
+
+        public Builder setSource(DataFrameAnalyticsSource source) {
+            this.source = Objects.requireNonNull(source);
+            return this;
+        }
+
+        public Builder setDest(DataFrameAnalyticsDest dest) {
+            this.dest = Objects.requireNonNull(dest);
+            return this;
+        }
+
+        public Builder setAnalysis(DataFrameAnalysis analysis) {
+            this.analysis = Objects.requireNonNull(analysis);
+            return this;
+        }
+
+        public Builder setAnalyzedFields(FetchSourceContext fields) {
+            this.analyzedFields = fields;
+            return this;
+        }
+
+        public Builder setModelMemoryLimit(ByteSizeValue modelMemoryLimit) {
+            this.modelMemoryLimit = modelMemoryLimit;
+            return this;
+        }
+
+        /**
+         * @return the config
+         * @throws NullPointerException if id, source, dest or analysis was never set
+         */
+        public DataFrameAnalyticsConfig build() {
+            return new DataFrameAnalyticsConfig(id, source, dest, analysis, analyzedFields, modelMemoryLimit);
+        }
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java
new file mode 100644
index 0000000000000..4123f85ee2f43
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Destination of a data frame analytics job: a required index and an optional results field.
+ * Instances are created through {@link Builder}.
+ */
+public class DataFrameAnalyticsDest implements ToXContentObject {
+
+    /**
+     * Parses a destination: the parser fills a {@link Builder}, which is then built.
+     */
+    public static DataFrameAnalyticsDest fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null).build();
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    private static final ParseField INDEX = new ParseField("index");
+    private static final ParseField RESULTS_FIELD = new ParseField("results_field");
+
+    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>("data_frame_analytics_dest", true, Builder::new);
+
+    static {
+        PARSER.declareString(Builder::setIndex, INDEX);
+        PARSER.declareString(Builder::setResultsField, RESULTS_FIELD);
+    }
+
+    private final String index;
+    private final String resultsField;
+
+    // The index is required: a null index throws NullPointerException.
+    private DataFrameAnalyticsDest(String index, @Nullable String resultsField) {
+        this.index = requireNonNull(index);
+        this.resultsField = resultsField;
+    }
+
+    public String getIndex() {
+        return index;
+    }
+
+    /** @return the results field, or {@code null} when none was set */
+    public String getResultsField() {
+        return resultsField;
+    }
+
+    /**
+     * Writes the index and, only when it is set, the results field.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(INDEX.getPreferredName(), index);
+        if (resultsField != null) {
+            builder.field(RESULTS_FIELD.getPreferredName(), resultsField);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        DataFrameAnalyticsDest other = (DataFrameAnalyticsDest) o;
+        return Objects.equals(index, other.index)
+            && Objects.equals(resultsField, other.resultsField);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(index, resultsField);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    /**
+     * Builder for {@link DataFrameAnalyticsDest}. The setters do not validate;
+     * a missing index is rejected in {@link #build()}.
+     */
+    public static class Builder {
+
+        private String index;
+        private String resultsField;
+
+        private Builder() {}
+
+        public Builder setIndex(String index) {
+            this.index = index;
+            return this;
+        }
+
+        public Builder setResultsField(String resultsField) {
+            this.resultsField = resultsField;
+            return this;
+        }
+
+        /**
+         * @return the destination
+         * @throws NullPointerException if no index was set
+         */
+        public DataFrameAnalyticsDest build() {
+            return new DataFrameAnalyticsDest(index, resultsField);
+        }
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java
new file mode 100644
index 0000000000000..c36799cd3b4a7
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSource.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Source of a data frame analytics job: a required index and an optional query.
+ * Instances are created through {@link Builder}.
+ */
+public class DataFrameAnalyticsSource implements ToXContentObject {
+
+    /**
+     * Parses a source: the parser fills a {@link Builder}, which is then built.
+     */
+    public static DataFrameAnalyticsSource fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null).build();
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    private static final ParseField INDEX = new ParseField("index");
+    private static final ParseField QUERY = new ParseField("query");
+
+    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>("data_frame_analytics_source", true, Builder::new);
+
+    static {
+        PARSER.declareString(Builder::setIndex, INDEX);
+        PARSER.declareObject(Builder::setQueryConfig, (p, c) -> QueryConfig.fromXContent(p), QUERY);
+    }
+
+    private final String index;
+    private final QueryConfig queryConfig;
+
+    // The index is required: a null index throws NullPointerException.
+    private DataFrameAnalyticsSource(String index, @Nullable QueryConfig queryConfig) {
+        this.index = Objects.requireNonNull(index);
+        this.queryConfig = queryConfig;
+    }
+
+    public String getIndex() {
+        return index;
+    }
+
+    /** @return the query config, or {@code null} when none was set */
+    public QueryConfig getQueryConfig() {
+        return queryConfig;
+    }
+
+    /**
+     * Writes the index and, only when a query config is set, the query it holds.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(INDEX.getPreferredName(), index);
+        if (queryConfig != null) {
+            builder.field(QUERY.getPreferredName(), queryConfig.getQuery());
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == this) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        DataFrameAnalyticsSource other = (DataFrameAnalyticsSource) o;
+        return Objects.equals(index, other.index)
+            && Objects.equals(queryConfig, other.queryConfig);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(index, queryConfig);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    /**
+     * Builder for {@link DataFrameAnalyticsSource}. The setters do not validate;
+     * a missing index is rejected in {@link #build()}.
+     */
+    public static class Builder {
+
+        private String index;
+        private QueryConfig queryConfig;
+
+        private Builder() {}
+
+        public Builder setIndex(String index) {
+            this.index = index;
+            return this;
+        }
+
+        public Builder setQueryConfig(QueryConfig queryConfig) {
+            this.queryConfig = queryConfig;
+            return this;
+        }
+
+        /**
+         * @return the source
+         * @throws NullPointerException if no index was set
+         */
+        public DataFrameAnalyticsSource build() {
+            return new DataFrameAnalyticsSource(index, queryConfig);
+        }
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java
new file mode 100644
index 0000000000000..6ee349b8e8d38
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsState.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import java.util.Locale;
+
+/**
+ * Enumerates the data frame analytics states and converts them to and from their lower-case string form.
+ */
+public enum DataFrameAnalyticsState {
+    STARTED, REINDEXING, ANALYZING, STOPPING, STOPPED;
+
+    /**
+     * Looks up a state by name, ignoring case and surrounding whitespace.
+     *
+     * @throws IllegalArgumentException if no state has that name
+     */
+    public static DataFrameAnalyticsState fromString(String name) {
+        final String constantName = name.trim().toUpperCase(Locale.ROOT);
+        return DataFrameAnalyticsState.valueOf(constantName);
+    }
+
+    /** Returns this state's name in lower case. */
+    public String value() {
+        final String constantName = name();
+        return constantName.toLowerCase(Locale.ROOT);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java
new file mode 100644
index 0000000000000..5c652f33edb2e
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.client.ml.NodeAttributes;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.inject.internal.ToStringBuilder;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Statistics for one data frame analytics: its id and state, plus optional progress,
+ * node and assignment explanation.
+ */
+public class DataFrameAnalyticsStats {
+
+    /** Parses a stats object; {@code id} and {@code state} are required, the other fields optional. */
+    public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    static final ParseField ID = new ParseField("id");
+    static final ParseField STATE = new ParseField("state");
+    static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
+    static final ParseField NODE = new ParseField("node");
+    static final ParseField ASSIGNMENT_EXPLANATION = new ParseField("assignment_explanation");
+
+    // Typed parser: a raw ConstructingObjectParser would make parse(...) return Object.
+    // The args order below must match the declaration order in the static block.
+    private static final ConstructingObjectParser<DataFrameAnalyticsStats, Void> PARSER =
+        new ConstructingObjectParser<>("data_frame_analytics_stats", true,
+            args -> new DataFrameAnalyticsStats(
+                (String) args[0],
+                (DataFrameAnalyticsState) args[1],
+                (Integer) args[2],
+                (NodeAttributes) args[3],
+                (String) args[4]));
+
+    static {
+        PARSER.declareString(constructorArg(), ID);
+        PARSER.declareField(constructorArg(), p -> {
+            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                return DataFrameAnalyticsState.fromString(p.text());
+            }
+            throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
+        }, STATE, ObjectParser.ValueType.STRING);
+        PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
+        PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE);
+        PARSER.declareString(optionalConstructorArg(), ASSIGNMENT_EXPLANATION);
+    }
+
+    private final String id;
+    private final DataFrameAnalyticsState state;
+    private final Integer progressPercent;
+    private final NodeAttributes node;
+    private final String assignmentExplanation;
+
+    /**
+     * @param id the analytics id
+     * @param state the current state
+     * @param progressPercent the progress, or {@code null} when not reported
+     * @param node the node attributes, or {@code null} when not reported
+     * @param assignmentExplanation the assignment explanation, or {@code null} when not reported
+     */
+    public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable Integer progressPercent,
+                                   @Nullable NodeAttributes node, @Nullable String assignmentExplanation) {
+        this.id = id;
+        this.state = state;
+        this.progressPercent = progressPercent;
+        this.node = node;
+        this.assignmentExplanation = assignmentExplanation;
+    }
+
+    public String getId() {
+        return id;
+    }
+
+    public DataFrameAnalyticsState getState() {
+        return state;
+    }
+
+    /** Returns the progress, or {@code null} when none was supplied. */
+    public Integer getProgressPercent() {
+        return progressPercent;
+    }
+
+    /** Returns the node attributes, or {@code null} when none were supplied. */
+    public NodeAttributes getNode() {
+        return node;
+    }
+
+    /** Returns the assignment explanation, or {@code null} when none was supplied. */
+    public String getAssignmentExplanation() {
+        return assignmentExplanation;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        DataFrameAnalyticsStats other = (DataFrameAnalyticsStats) o;
+        return Objects.equals(id, other.id)
+            && Objects.equals(state, other.state)
+            && Objects.equals(progressPercent, other.progressPercent)
+            && Objects.equals(node, other.node)
+            && Objects.equals(assignmentExplanation, other.assignmentExplanation);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(id, state, progressPercent, node, assignmentExplanation);
+    }
+
+    @Override
+    public String toString() {
+        return new ToStringBuilder(getClass())
+            .add("id", id)
+            .add("state", state)
+            .add("progressPercent", progressPercent)
+            .add("node", node)
+            .add("assignmentExplanation", assignmentExplanation)
+            .toString();
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java new file mode 100644 index 0000000000000..3b78c60be91fd --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/MlDataFrameAnalysisNamedXContentProvider.java @@ -0,0 +1,37 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.plugins.spi.NamedXContentProvider;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Supplies the named XContent parser entries for {@link DataFrameAnalysis} implementations.
+ */
+public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider {
+
+    /**
+     * Returns one entry mapping {@link OutlierDetection#NAME} to {@link OutlierDetection#fromXContent}.
+     */
+    @Override
+    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
+        return Arrays.asList(
+            new NamedXContentRegistry.Entry(
+                DataFrameAnalysis.class,
+                OutlierDetection.NAME,
+                (p, c) -> OutlierDetection.fromXContent(p)));
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java
new file mode 100644
index 0000000000000..946c01ac5c835
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/OutlierDetection.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Objects;
+
+/**
+ * Outlier detection analysis configuration. All three settings are optional and are
+ * only serialized when set.
+ */
+public class OutlierDetection implements DataFrameAnalysis {
+
+    /** Parses an outlier detection configuration from the given parser. */
+    public static OutlierDetection fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null).build();
+    }
+
+    /** Returns a configuration with no setting specified. */
+    public static OutlierDetection createDefault() {
+        return builder().build();
+    }
+
+    public static Builder builder() {
+        return new Builder();
+    }
+
+    public static final ParseField NAME = new ParseField("outlier_detection");
+    static final ParseField N_NEIGHBORS = new ParseField("n_neighbors");
+    static final ParseField METHOD = new ParseField("method");
+    public static final ParseField MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE =
+        new ParseField("minimum_score_to_write_feature_influence");
+
+    // Typed and final: with a raw ObjectParser, apply(...) yields Object and build() cannot be called on it.
+    private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME.getPreferredName(), true, Builder::new);
+
+    static {
+        PARSER.declareInt(Builder::setNNeighbors, N_NEIGHBORS);
+        PARSER.declareField(Builder::setMethod, p -> {
+            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                return Method.fromString(p.text());
+            }
+            throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
+        }, METHOD, ObjectParser.ValueType.STRING);
+        PARSER.declareDouble(Builder::setMinScoreToWriteFeatureInfluence, MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE);
+    }
+
+    private final Integer nNeighbors;
+    private final Method method;
+    private final Double minScoreToWriteFeatureInfluence;
+
+    /**
+     * Constructs the outlier detection configuration
+     * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection.
+     * @param method The method. Leave unspecified for a dynamic mixture of methods.
+     * @param minScoreToWriteFeatureInfluence The min outlier score required to calculate feature influence. Defaults to 0.1.
+     */
+    private OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method, @Nullable Double minScoreToWriteFeatureInfluence) {
+        this.nNeighbors = nNeighbors;
+        this.method = method;
+        this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence;
+    }
+
+    @Override
+    public String getName() {
+        return NAME.getPreferredName();
+    }
+
+    /** Returns the number of neighbors, or {@code null} when unspecified. */
+    public Integer getNNeighbors() {
+        return nNeighbors;
+    }
+
+    /** Returns the method, or {@code null} when unspecified. */
+    public Method getMethod() {
+        return method;
+    }
+
+    /** Returns the minimum score to write feature influence, or {@code null} when unspecified. */
+    public Double getMinScoreToWriteFeatureInfluence() {
+        return minScoreToWriteFeatureInfluence;
+    }
+
+    /** Writes an object containing only the settings that are non-null. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (nNeighbors != null) {
+            builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors);
+        }
+        if (method != null) {
+            builder.field(METHOD.getPreferredName(), method);
+        }
+        if (minScoreToWriteFeatureInfluence != null) {
+            builder.field(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+
+        OutlierDetection other = (OutlierDetection) o;
+        return Objects.equals(nNeighbors, other.nNeighbors)
+            && Objects.equals(method, other.method)
+            && Objects.equals(minScoreToWriteFeatureInfluence, other.minScoreToWriteFeatureInfluence);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(nNeighbors, method, minScoreToWriteFeatureInfluence);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+
+    /** The available methods; {@link #toString()} yields the lower-case name. */
+    public enum Method {
+        LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;
+
+        /**
+         * Looks up a method by name, ignoring case.
+         *
+         * @throws IllegalArgumentException if no method has that name
+         */
+        public static Method fromString(String value) {
+            return Method.valueOf(value.toUpperCase(Locale.ROOT));
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
+    }
+
+    /** Mutable builder for {@link OutlierDetection}; also the target object of the parser. */
+    public static class Builder {
+
+        private Integer nNeighbors;
+        private Method method;
+        private Double minScoreToWriteFeatureInfluence;
+
+        private Builder() {}
+
+        public Builder setNNeighbors(Integer nNeighbors) {
+            this.nNeighbors = nNeighbors;
+            return this;
+        }
+
+        public Builder setMethod(Method method) {
+            this.method = method;
+            return this;
+        }
+
+        public Builder setMinScoreToWriteFeatureInfluence(Double minScoreToWriteFeatureInfluence) {
+            this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence;
+            return this;
+        }
+
+        public OutlierDetection build() {
+            return new OutlierDetection(nNeighbors, method, minScoreToWriteFeatureInfluence);
+        }
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java
new file mode 100644
index 0000000000000..ae704db9f800e
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/QueryConfig.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml.dataframe;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Wraps the {@link QueryBuilder} to use for a DataFrameAnalysis; the wrapped query is never null.
+ */
+public class QueryConfig implements ToXContentObject {
+
+    /** Parses an inner query from the parser and wraps it. */
+    public static QueryConfig fromXContent(XContentParser parser) throws IOException {
+        return new QueryConfig(AbstractQueryBuilder.parseInnerQueryBuilder(parser));
+    }
+
+    private final QueryBuilder query;
+
+    /**
+     * @param query the query to wrap
+     * @throws NullPointerException if {@code query} is null
+     */
+    public QueryConfig(QueryBuilder query) {
+        this.query = requireNonNull(query);
+    }
+
+    /** Copy constructor: wraps the same query instance as {@code queryConfig}. */
+    public QueryConfig(QueryConfig queryConfig) {
+        this(requireNonNull(queryConfig).query);
+    }
+
+    /** Delegates serialization to the wrapped query and returns the given builder. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        query.toXContent(builder, params);
+        return builder;
+    }
+
+    public QueryBuilder getQuery() {
+        return query;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        final QueryConfig that = (QueryConfig) o;
+        return Objects.equals(query, that.query);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(query);
+    }
+
+    @Override
+    public String toString() {
+        return Strings.toString(this);
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java
new file mode 100644
index 0000000000000..78578597e195b
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/Evaluation.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation;
+
+import org.elasticsearch.common.xcontent.ToXContentObject;
+
+/**
+ * Defines an evaluation: a named object that can be written as an XContent object.
+ */
+public interface Evaluation extends ToXContentObject {
+
+    /**
+     * Returns the name that identifies this evaluation.
+     */
+    String getName();
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java
new file mode 100644
index 0000000000000..a0f77838f1fd0
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/EvaluationMetric.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership.
Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation;
+
+import org.elasticsearch.common.xcontent.ToXContentObject;
+
+/**
+ * Defines an evaluation metric: a named object that can be written as an XContent object.
+ */
+public interface EvaluationMetric extends ToXContentObject {
+
+    /**
+     * Returns the name that identifies this metric.
+     */
+    String getName();
+
+    /**
+     * The result of an evaluation metric; it can be written as an XContent object.
+     */
+    interface Result extends ToXContentObject {
+
+        /**
+         * Returns the name of the metric this result belongs to.
+         */
+        String getMetricName();
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
new file mode 100644
index 0000000000000..764ff41de86e0
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.plugins.spi.NamedXContentProvider;
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric;
+import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Supplies the named XContent parser entries for evaluations, evaluation metrics and metric results.
+ */
+public class MlEvaluationNamedXContentProvider implements NamedXContentProvider {
+
+    /**
+     * Returns the entries for {@link Evaluation}, {@link EvaluationMetric} and
+     * {@link EvaluationMetric.Result}; each metric and its result are registered under the same name.
+     */
+    @Override
+    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
+        return Arrays.asList(
+            // Evaluations
+            new NamedXContentRegistry.Entry(
+                Evaluation.class, new ParseField(BinarySoftClassification.NAME), BinarySoftClassification::fromXContent),
+            // Evaluation metrics
+            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(AucRocMetric.NAME), AucRocMetric::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric::fromXContent),
+            new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallMetric.NAME), RecallMetric::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric::fromXContent),
+            // Evaluation metrics results
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class, new ParseField(AucRocMetric.NAME), AucRocMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class, new ParseField(PrecisionMetric.NAME), PrecisionMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class, new ParseField(RecallMetric.NAME), RecallMetric.Result::fromXContent),
+            new NamedXContentRegistry.Entry(
+                EvaluationMetric.Result.class, new ParseField(ConfusionMatrixMetric.NAME), ConfusionMatrixMetric.Result::fromXContent));
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java
new file mode 100644
index 0000000000000..f41c13f248ab9
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Base class for metrics configured by a list of thresholds, serialized under the {@code at} field.
+ */
+abstract class AbstractConfusionMatrixMetric implements EvaluationMetric {
+
+    protected static final ParseField AT = new ParseField("at");
+
+    protected final double[] thresholds;
+
+    /**
+     * Copies the given thresholds into a primitive array.
+     *
+     * @param at the thresholds
+     * @throws NullPointerException if {@code at} or any of its elements is null
+     */
+    protected AbstractConfusionMatrixMetric(List<Double> at) {
+        // List<Double> rather than a raw List: mapToDouble(Double::doubleValue) needs the element type.
+        this.thresholds = Objects.requireNonNull(at).stream().mapToDouble(Double::doubleValue).toArray();
+    }
+
+    /** Writes an object with the thresholds as its single {@code at} field. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        return builder
+            .startObject()
+            .field(AT.getPreferredName(), thresholds)
+            .endObject();
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java
new file mode 100644
index 0000000000000..78c713c592581
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/AucRocMetric.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.client.ml.dataframe.evaluation.softclassification;
+
+import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
+import org.elasticsearch.common.Nullable;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ToXContent;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
+
+/**
+ * Area under the curve (AUC) of the receiver operating characteristic (ROC).
+ * The ROC curve is a plot of the TPR (true positive rate) against
+ * the FPR (false positive rate) over a varying threshold.
+ */
+public class AucRocMetric implements EvaluationMetric {
+
+    public static final String NAME = "auc_roc";
+
+    public static final ParseField INCLUDE_CURVE = new ParseField("include_curve");
+
+    @SuppressWarnings("unchecked")
+    public static final ConstructingObjectParser<AucRocMetric, Void> PARSER =
+        new ConstructingObjectParser<>(NAME, args -> new AucRocMetric((Boolean) args[0]));
+
+    static {
+        PARSER.declareBoolean(optionalConstructorArg(), INCLUDE_CURVE);
+    }
+
+    /** Parses the metric; {@code include_curve} is optional. */
+    public static AucRocMetric fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /** Returns a metric with {@code include_curve} set to true. */
+    public static AucRocMetric withCurve() {
+        return new AucRocMetric(true);
+    }
+
+    private final boolean includeCurve;
+
+    /**
+     * @param includeCurve whether to include the curve; {@code null} is treated as {@code false}
+     */
+    public AucRocMetric(Boolean includeCurve) {
+        this.includeCurve = includeCurve == null ? false : includeCurve;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        return builder
+            .startObject()
+            .field(INCLUDE_CURVE.getPreferredName(), includeCurve)
+            .endObject();
+    }
+
+    @Override
+    public String getName() {
+        return NAME;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        AucRocMetric that = (AucRocMetric) o;
+        return Objects.equals(includeCurve, that.includeCurve);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(includeCurve);
+    }
+
+    /**
+     * Result of the metric: a mandatory score and an optional list of curve points.
+     */
+    public static class Result implements EvaluationMetric.Result {
+
+        public static Result fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private static final ParseField SCORE = new ParseField("score");
+        private static final ParseField CURVE = new ParseField("curve");
+
+        // The unchecked cast is safe: args[1] is filled by the CURVE declaration below,
+        // whose element parser only produces AucRocPoint instances.
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<Result, Void> PARSER =
+            new ConstructingObjectParser<>("auc_roc_result", true, args -> new Result((double) args[0], (List<AucRocPoint>) args[1]));
+
+        static {
+            PARSER.declareDouble(constructorArg(), SCORE);
+            PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> AucRocPoint.fromXContent(p), CURVE);
+        }
+
+        private final double score;
+        private final List<AucRocPoint> curve;
+
+        public Result(double score, @Nullable List<AucRocPoint> curve) {
+            this.score = score;
+            this.curve = curve;
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        public double getScore() {
+            return score;
+        }
+
+        /** Returns an unmodifiable view of the curve, or {@code null} when no curve was supplied. */
+        public List<AucRocPoint> getCurve() {
+            return curve == null ? null : Collections.unmodifiableList(curve);
+        }
+
+        /** Writes the score and, only when the curve is non-null and non-empty, the curve. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+            builder.startObject();
+            builder.field(SCORE.getPreferredName(), score);
+            if (curve != null && curve.isEmpty() == false) {
+                builder.field(CURVE.getPreferredName(), curve);
+            }
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Result that = (Result) o;
+            return Objects.equals(score, that.score)
+                && Objects.equals(curve, that.curve);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(score, curve);
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+    }
+
+    /**
+     * One point of the ROC curve: true positive rate, false positive rate and the threshold they apply to.
+     */
+    public static final class AucRocPoint implements ToXContentObject {
+
+        public static AucRocPoint fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        private static final ParseField TPR = new ParseField("tpr");
+        private static final ParseField FPR = new ParseField("fpr");
+        private static final ParseField THRESHOLD = new ParseField("threshold");
+
+        @SuppressWarnings("unchecked")
+        private static final ConstructingObjectParser<AucRocPoint, Void> PARSER =
+            new ConstructingObjectParser<>(
+                "auc_roc_point",
+                true,
+                args -> new AucRocPoint((double) args[0], (double) args[1], (double) args[2]));
+
+        static {
+            PARSER.declareDouble(constructorArg(), TPR);
+            PARSER.declareDouble(constructorArg(), FPR);
+            PARSER.declareDouble(constructorArg(), THRESHOLD);
+        }
+
+        private final double tpr;
+        private final double fpr;
+        private final double threshold;
+
+        public AucRocPoint(double tpr, double fpr, double threshold) {
+            this.tpr = tpr;
+            this.fpr = fpr;
+            this.threshold = threshold;
+        }
+
+        public double getTruePositiveRate() {
+            return tpr;
+        }
+
+        public double getFalsePositiveRate() {
+            return fpr;
+        }
+
+        public double getThreshold() {
+            return threshold;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder
+                .startObject()
+                .field(TPR.getPreferredName(), tpr)
+                .field(FPR.getPreferredName(), fpr)
+                .field(THRESHOLD.getPreferredName(), threshold)
+                .endObject();
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            AucRocPoint that = (AucRocPoint) o;
+            // Double.compare rather than ==: hashCode hashes the boxed values, for which 0.0 and -0.0
+            // differ and NaN equals NaN, so == would break the equals/hashCode contract.
+            return Double.compare(tpr, that.tpr) == 0
+                && Double.compare(fpr, that.fpr) == 0
+                && Double.compare(threshold, that.threshold) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(tpr, fpr, threshold);
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+    }
+}
diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java
new file mode 100644
index 0000000000000..6d5fa04da38e5
--- /dev/null
+++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership.
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +/** + * Evaluation of binary soft classification methods, e.g. outlier detection. + * This is useful to evaluate problems where a model outputs a probability of whether + * a data frame row belongs to one of two groups. 
+ */ +public class BinarySoftClassification implements Evaluation { + + public static final String NAME = "binary_soft_classification"; + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + NAME, + args -> new BinarySoftClassification((String) args[0], (String) args[1], (List) args[2])); + + static { + PARSER.declareString(constructorArg(), ACTUAL_FIELD); + PARSER.declareString(constructorArg(), PREDICTED_PROBABILITY_FIELD); + PARSER.declareNamedObjects(optionalConstructorArg(), (p, c, n) -> p.namedObject(EvaluationMetric.class, n, null), METRICS); + } + + public static BinarySoftClassification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field where the actual class is marked up. + * The value of this field is assumed to either be 1 or 0, or true or false. + */ + private final String actualField; + + /** + * The field of the predicted probability in [0.0, 1.0]. + */ + private final String predictedProbabilityField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public BinarySoftClassification(String actualField, String predictedProbabilityField, EvaluationMetric... 
metric) { + this(actualField, predictedProbabilityField, Arrays.asList(metric)); + } + + public BinarySoftClassification(String actualField, String predictedProbabilityField, + @Nullable List metrics) { + this.actualField = Objects.requireNonNull(actualField); + this.predictedProbabilityField = Objects.requireNonNull(predictedProbabilityField); + this.metrics = Objects.requireNonNull(metrics); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); + + builder.startObject(METRICS.getPreferredName()); + for (EvaluationMetric metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BinarySoftClassification that = (BinarySoftClassification) o; + return Objects.equals(actualField, that.actualField) + && Objects.equals(predictedProbabilityField, that.predictedProbabilityField) + && Objects.equals(metrics, that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedProbabilityField, metrics); + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java new file mode 100644 index 0000000000000..d5e4307c9cc74 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/ConfusionMatrixMetric.java @@ -0,0 +1,206 @@ +/* + * 
Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class ConfusionMatrixMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "confusion_matrix"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, args -> new ConfusionMatrixMetric((List) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } + + public static 
ConfusionMatrixMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static ConfusionMatrixMetric at(Double... at) { + return new ConfusionMatrixMetric(Arrays.asList(at)); + } + + public ConfusionMatrixMetric(List at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfusionMatrixMetric that = (ConfusionMatrixMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, ConfusionMatrix::fromXContent)); + } + + private final Map results; + + public Result(Map results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + public ConfusionMatrix getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static final class ConfusionMatrix implements ToXContentObject { + + public static ConfusionMatrix fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private static final ParseField TP = new ParseField("tp"); + 
private static final ParseField FP = new ParseField("fp"); + private static final ParseField TN = new ParseField("tn"); + private static final ParseField FN = new ParseField("fn"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "confusion_matrix", true, args -> new ConfusionMatrix((long) args[0], (long) args[1], (long) args[2], (long) args[3])); + + static { + PARSER.declareLong(constructorArg(), TP); + PARSER.declareLong(constructorArg(), FP); + PARSER.declareLong(constructorArg(), TN); + PARSER.declareLong(constructorArg(), FN); + } + + private final long tp; + private final long fp; + private final long tn; + private final long fn; + + public ConfusionMatrix(long tp, long fp, long tn, long fn) { + this.tp = tp; + this.fp = fp; + this.tn = tn; + this.fn = fn; + } + + public long getTruePositives() { + return tp; + } + + public long getFalsePositives() { + return fp; + } + + public long getTrueNegatives() { + return tn; + } + + public long getFalseNegatives() { + return fn; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(TP.getPreferredName(), tp) + .field(FP.getPreferredName(), fp) + .field(TN.getPreferredName(), tn) + .field(FN.getPreferredName(), fn) + .endObject(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfusionMatrix that = (ConfusionMatrix) o; + return tp == that.tp && fp == that.fp && tn == that.tn && fn == that.fn; + } + + @Override + public int hashCode() { + return Objects.hash(tp, fp, tn, fn); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java new file mode 100644 index 0000000000000..2a0f1499461d6 --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/PrecisionMetric.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class PrecisionMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "precision"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, args -> new PrecisionMetric((List) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } + + public static PrecisionMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static PrecisionMetric at(Double... 
at) { + return new PrecisionMetric(Arrays.asList(at)); + } + + public PrecisionMetric(List at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PrecisionMetric that = (PrecisionMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, p -> p.doubleValue())); + } + + private final Map results; + + public Result(Map results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + public Double getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java new file mode 100644 index 0000000000000..505ff1b34d7c5 --- /dev/null +++ 
b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/softclassification/RecallMetric.java @@ -0,0 +1,123 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class RecallMetric extends AbstractConfusionMatrixMetric { + + public static final String NAME = "recall"; + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, args -> new RecallMetric((List) args[0])); + + static { + PARSER.declareDoubleArray(constructorArg(), AT); + } + 
+ public static RecallMetric fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + public static RecallMetric at(Double... at) { + return new RecallMetric(Arrays.asList(at)); + } + + public RecallMetric(List at) { + super(at); + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RecallMetric that = (RecallMetric) o; + return Arrays.equals(thresholds, that.thresholds); + } + + @Override + public int hashCode() { + return Arrays.hashCode(thresholds); + } + + public static class Result implements EvaluationMetric.Result { + + public static Result fromXContent(XContentParser parser) throws IOException { + return new Result(parser.map(LinkedHashMap::new, p -> p.doubleValue())); + } + + private final Map results; + + public Result(Map results) { + this.results = Objects.requireNonNull(results); + } + + @Override + public String getMetricName() { + return NAME; + } + + public Double getScoreByThreshold(String threshold) { + return results.get(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.map(results); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Result that = (Result) o; + return Objects.equals(results, that.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } +} diff --git a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider index 4204a868246a5..dde81e43867d8 100644 --- 
a/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider +++ b/client/rest-high-level/src/main/resources/META-INF/services/org.elasticsearch.plugins.spi.NamedXContentProvider @@ -1 +1,4 @@ -org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider \ No newline at end of file +org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider +org.elasticsearch.client.indexlifecycle.IndexLifecycleNamedXContentProvider +org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider +org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider \ No newline at end of file diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java index 36e19896ea3c8..26a4ade504682 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/DataFrameRequestConvertersTests.java @@ -23,6 +23,7 @@ import org.apache.http.client.methods.HttpGet; import org.apache.http.client.methods.HttpPost; import org.apache.http.client.methods.HttpPut; +import org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider; import org.elasticsearch.client.core.PageParams; import org.elasticsearch.client.dataframe.DeleteDataFrameTransformRequest; import org.elasticsearch.client.dataframe.GetDataFrameTransformRequest; @@ -43,6 +44,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; @@ -53,7 +55,10 @@ public class DataFrameRequestConvertersTests extends ESTestCase { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, 
Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } public void testPutDataFrameTransform() throws IOException { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java index fd867a12204d0..9bb2bb42fd9d7 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java @@ -28,12 +28,14 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteFilterRequest; import org.elasticsearch.client.ml.DeleteForecastRequest; import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureRequestTests; import org.elasticsearch.client.ml.FlushJobRequest; @@ -42,6 +44,8 @@ import org.elasticsearch.client.ml.GetCalendarEventsRequest; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCategoriesRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; import 
org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; import org.elasticsearch.client.ml.GetFiltersRequest; @@ -58,13 +62,16 @@ import org.elasticsearch.client.ml.PreviewDatafeedRequest; import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutFilterRequest; import org.elasticsearch.client.ml.PutJobRequest; import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedRequestTests; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.UpdateFilterRequest; import org.elasticsearch.client.ml.UpdateJobRequest; @@ -75,6 +82,12 @@ import org.elasticsearch.client.ml.calendars.ScheduledEventTests; import org.elasticsearch.client.ml.datafeed.DatafeedConfig; import org.elasticsearch.client.ml.datafeed.DatafeedConfigTests; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; 
import org.elasticsearch.client.ml.job.config.Detector; @@ -84,23 +97,30 @@ import org.elasticsearch.client.ml.job.config.MlFilter; import org.elasticsearch.client.ml.job.config.MlFilterTests; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.ESTestCase; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfigTests.randomDataFrameAnalyticsConfig; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.is; import static org.hamcrest.core.IsNull.nullValue; @@ -154,7 +174,6 @@ public void testGetJobStats() { assertEquals(Boolean.toString(true), request.getParameters().get("allow_no_jobs")); } - public void testOpenJob() throws Exception { String jobId = "some-job-id"; OpenJobRequest openJobRequest = new OpenJobRequest(jobId); @@ -669,6 +688,109 @@ public void testDeleteCalendarEvent() { assertEquals("/_ml/calendars/" + calendarId + "/events/" + eventId, request.getEndpoint()); } + public void testPutDataFrameAnalytics() throws IOException { + PutDataFrameAnalyticsRequest putRequest = new PutDataFrameAnalyticsRequest(randomDataFrameAnalyticsConfig()); + Request request = 
MLRequestConverters.putDataFrameAnalytics(putRequest); + assertEquals(HttpPut.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + putRequest.getConfig().getId(), request.getEndpoint()); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) { + DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.fromXContent(parser); + assertThat(parsedConfig, equalTo(putRequest.getConfig())); + } + } + + public void testGetDataFrameAnalytics() { + String configId1 = randomAlphaOfLength(10); + String configId2 = randomAlphaOfLength(10); + String configId3 = randomAlphaOfLength(10); + GetDataFrameAnalyticsRequest getRequest = new GetDataFrameAnalyticsRequest(configId1, configId2, configId3) + .setAllowNoMatch(false) + .setPageParams(new PageParams(100, 300)); + + Request request = MLRequestConverters.getDataFrameAnalytics(getRequest); + assertEquals(HttpGet.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3, request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), hasEntry("size", "300"), hasEntry("allow_no_match", "false"))); + assertNull(request.getEntity()); + } + + public void testGetDataFrameAnalyticsStats() { + String configId1 = randomAlphaOfLength(10); + String configId2 = randomAlphaOfLength(10); + String configId3 = randomAlphaOfLength(10); + GetDataFrameAnalyticsStatsRequest getStatsRequest = new GetDataFrameAnalyticsStatsRequest(configId1, configId2, configId3) + .setAllowNoMatch(false) + .setPageParams(new PageParams(100, 300)); + + Request request = MLRequestConverters.getDataFrameAnalyticsStats(getStatsRequest); + assertEquals(HttpGet.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + configId1 + "," + configId2 + "," + configId3 + "/_stats", request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("from", "100"), 
hasEntry("size", "300"), hasEntry("allow_no_match", "false"))); + assertNull(request.getEntity()); + } + + public void testStartDataFrameAnalytics() { + StartDataFrameAnalyticsRequest startRequest = new StartDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.startDataFrameAnalytics(startRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + startRequest.getId() + "/_start", request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testStartDataFrameAnalytics_WithTimeout() { + StartDataFrameAnalyticsRequest startRequest = new StartDataFrameAnalyticsRequest(randomAlphaOfLength(10)) + .setTimeout(TimeValue.timeValueMinutes(1)); + Request request = MLRequestConverters.startDataFrameAnalytics(startRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + startRequest.getId() + "/_start", request.getEndpoint()); + assertThat(request.getParameters(), hasEntry("timeout", "1m")); + assertNull(request.getEntity()); + } + + public void testStopDataFrameAnalytics() { + StopDataFrameAnalyticsRequest stopRequest = new StopDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.stopDataFrameAnalytics(stopRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + stopRequest.getId() + "/_stop", request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testStopDataFrameAnalytics_WithParams() { + StopDataFrameAnalyticsRequest stopRequest = new StopDataFrameAnalyticsRequest(randomAlphaOfLength(10)) + .setTimeout(TimeValue.timeValueMinutes(1)) + .setAllowNoMatch(false); + Request request = MLRequestConverters.stopDataFrameAnalytics(stopRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + stopRequest.getId() + "/_stop", 
request.getEndpoint()); + assertThat(request.getParameters(), allOf(hasEntry("timeout", "1m"), hasEntry("allow_no_match", "false"))); + assertNull(request.getEntity()); + } + + public void testDeleteDataFrameAnalytics() { + DeleteDataFrameAnalyticsRequest deleteRequest = new DeleteDataFrameAnalyticsRequest(randomAlphaOfLength(10)); + Request request = MLRequestConverters.deleteDataFrameAnalytics(deleteRequest); + assertEquals(HttpDelete.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/analytics/" + deleteRequest.getId(), request.getEndpoint()); + assertNull(request.getEntity()); + } + + public void testEvaluateDataFrame() throws IOException { + EvaluateDataFrameRequest evaluateRequest = + new EvaluateDataFrameRequest( + Arrays.asList(generateRandomStringArray(1, 10, false, false)), + new BinarySoftClassification( + randomAlphaOfLengthBetween(1, 10), + randomAlphaOfLengthBetween(1, 10), + PrecisionMetric.at(0.5), RecallMetric.at(0.6, 0.7))); + Request request = MLRequestConverters.evaluateDataFrame(evaluateRequest); + assertEquals(HttpPost.METHOD_NAME, request.getMethod()); + assertEquals("/_ml/data_frame/_evaluate", request.getEndpoint()); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, request.getEntity().getContent())) { + EvaluateDataFrameRequest parsedRequest = EvaluateDataFrameRequest.fromXContent(parser); + assertThat(parsedRequest, equalTo(evaluateRequest)); + } + } + public void testPutFilter() throws IOException { MlFilter filter = MlFilterTests.createRandomBuilder("foo").build(); PutFilterRequest putFilterRequest = new PutFilterRequest(filter); @@ -835,6 +957,15 @@ public void testSetUpgradeMode() { assertThat(request.getParameters().get(SetUpgradeModeRequest.TIMEOUT.getPreferredName()), is("1h")); } + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + 
namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } + private static Job createValidJob(String jobId) { AnalysisConfig.Builder analysisConfig = AnalysisConfig.builder(Collections.singletonList( Detector.builder().setFunction("count").build())); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 8ef28733f2e12..77efe43b2e174 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -29,11 +29,13 @@ import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.client.core.PageParams; import org.elasticsearch.client.indices.CreateIndexRequest; +import org.elasticsearch.client.indices.GetIndexRequest; import org.elasticsearch.client.ml.CloseJobRequest; import org.elasticsearch.client.ml.CloseJobResponse; import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -42,6 +44,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; 
import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -52,6 +56,10 @@ import org.elasticsearch.client.ml.GetCalendarEventsResponse; import org.elasticsearch.client.ml.GetCalendarsRequest; import org.elasticsearch.client.ml.GetCalendarsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetDatafeedStatsRequest; @@ -77,6 +85,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -86,8 +96,11 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -103,6 +116,18 @@ import 
org.elasticsearch.client.ml.datafeed.DatafeedState; import org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.DataDescription; @@ -113,9 +138,12 @@ import org.elasticsearch.client.ml.job.config.MlFilter; import org.elasticsearch.client.ml.job.process.ModelSnapshot; import org.elasticsearch.client.ml.job.stats.JobStats; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.junit.After; @@ -136,6 +164,7 @@ import static 
org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.CoreMatchers.hasItems; import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -528,18 +557,7 @@ public void testStartDatafeed() throws Exception { String indexName = "start_data_1"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -611,18 +629,7 @@ public void testStopDatafeed() throws Exception { String indexName = "stop_data_1"; // Set up the index - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -684,18 +691,7 @@ public void testGetDatafeedStats() throws Exception { String indexName = "datafeed_stats_data_1"; // Set up the index - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - 
createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); // create the job and the datafeed Job job1 = buildJob(jobId1); @@ -762,18 +758,7 @@ public void testPreviewDatafeed() throws Exception { String indexName = "preview_data_1"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); long now = (System.currentTimeMillis()/1000)*1000; @@ -826,21 +811,9 @@ public void testDeleteExpiredDataGivenNothingToDelete() throws Exception { } private String createExpiredData(String jobId) throws Exception { - String indexId = jobId + "-data"; + String indexName = jobId + "-data"; // Set up the index and docs - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexId); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .field("format", "epoch_millis") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName, 
defaultMappingForTest()); BulkRequest bulk = new BulkRequest(); bulk.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -853,7 +826,7 @@ private String createExpiredData(String jobId) throws Exception { long timestamp = nowMillis - TimeValue.timeValueHours(totalBuckets - bucket).getMillis(); int bucketRate = bucket == anomalousBucket ? anomalousRate : normalRate; for (int point = 0; point < bucketRate; point++) { - IndexRequest indexRequest = new IndexRequest(indexId); + IndexRequest indexRequest = new IndexRequest(indexName); indexRequest.source(XContentType.JSON, "timestamp", timestamp, "total", randomInt(1000)); bulk.add(indexRequest); } @@ -872,7 +845,7 @@ private String createExpiredData(String jobId) throws Exception { Job job = buildJobForExpiredDataTests(jobId); putJob(job); openJob(job); - String datafeedId = createAndPutDatafeed(jobId, indexId); + String datafeedId = createAndPutDatafeed(jobId, indexName); startDatafeed(datafeedId, String.valueOf(0), String.valueOf(nowMillis - TimeValue.timeValueHours(24).getMillis())); @@ -1230,6 +1203,418 @@ public void testDeleteCalendarEvent() throws IOException { assertThat(remainingIds, not(hasItem(deletedEvent))); } + public void testPutDataFrameAnalyticsConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "put-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + createIndex("put-test-source-index", defaultMappingForTest()); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + 
DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + assertThat(createdConfig.getId(), equalTo(config.getId())); + assertThat(createdConfig.getSource().getIndex(), equalTo(config.getSource().getIndex())); + assertThat(createdConfig.getSource().getQueryConfig(), equalTo(new QueryConfig(new MatchAllQueryBuilder()))); // default value + assertThat(createdConfig.getDest().getIndex(), equalTo(config.getDest().getIndex())); + assertThat(createdConfig.getDest().getResultsField(), equalTo("ml")); // default value + assertThat(createdConfig.getAnalysis(), equalTo(config.getAnalysis())); + assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields())); + assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value + } + + public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "get-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("get-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("get-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + createIndex("get-test-source-index", defaultMappingForTest()); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + 
assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(1)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), contains(createdConfig)); + } + + public void testGetDataFrameAnalyticsConfig_MultipleConfigs() throws Exception { + createIndex("get-test-source-index", defaultMappingForTest()); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configIdPrefix = "get-test-config-"; + int numberOfConfigs = 10; + List createdConfigs = new ArrayList<>(); + for (int i = 0; i < numberOfConfigs; ++i) { + String configId = configIdPrefix + i; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("get-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("get-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + PutDataFrameAnalyticsResponse putDataFrameAnalyticsResponse = execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + DataFrameAnalyticsConfig createdConfig = putDataFrameAnalyticsResponse.getConfig(); + createdConfigs.add(createdConfig); + } + + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(numberOfConfigs)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), containsInAnyOrder(createdConfigs.toArray())); + } + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configIdPrefix + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + 
assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(numberOfConfigs)); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), containsInAnyOrder(createdConfigs.toArray())); + } + { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configIdPrefix + "9", configIdPrefix + "1", configIdPrefix + "4"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(3)); + assertThat( + getDataFrameAnalyticsResponse.getAnalytics(), + containsInAnyOrder(createdConfigs.get(1), createdConfigs.get(4), createdConfigs.get(9))); + } + { + GetDataFrameAnalyticsRequest getDataFrameAnalyticsRequest = new GetDataFrameAnalyticsRequest(configIdPrefix + "*"); + getDataFrameAnalyticsRequest.setPageParams(new PageParams(3, 4)); + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + getDataFrameAnalyticsRequest, + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(4)); + assertThat( + getDataFrameAnalyticsResponse.getAnalytics(), + containsInAnyOrder(createdConfigs.get(3), createdConfigs.get(4), createdConfigs.get(5), createdConfigs.get(6))); + } + } + + public void testGetDataFrameAnalyticsConfig_ConfigNotFound() { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("config_that_does_not_exist"); + ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, + () -> execute(request, machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync)); + assertThat(exception.status().getStatus(), equalTo(404)); + } + + public void testGetDataFrameAnalyticsStats() throws Exception { + String sourceIndex = 
"get-stats-test-source-index"; + String destIndex = "get-stats-test-dest-index"; + createIndex(sourceIndex, defaultMappingForTest()); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "get-stats-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + + GetDataFrameAnalyticsStatsResponse statsResponse = execute( + new GetDataFrameAnalyticsStatsRequest(configId), + machineLearningClient::getDataFrameAnalyticsStats, machineLearningClient::getDataFrameAnalyticsStatsAsync); + + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + assertThat(stats.getId(), equalTo(configId)); + assertThat(stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED)); + assertNull(stats.getProgressPercent()); + assertNull(stats.getNode()); + assertNull(stats.getAssignmentExplanation()); + assertThat(statsResponse.getNodeFailures(), hasSize(0)); + assertThat(statsResponse.getTaskFailures(), hasSize(0)); + } + + public void testStartDataFrameAnalyticsConfig() throws Exception { + String sourceIndex = "start-test-source-index"; + String destIndex = "start-test-dest-index"; + createIndex(sourceIndex, defaultMappingForTest()); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); + + // Verify 
that the destination index does not exist. Otherwise, analytics' reindexing step would fail. + assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "start-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + + AcknowledgedResponse startDataFrameAnalyticsResponse = execute( + new StartDataFrameAnalyticsRequest(configId), + machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); + assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); + + // Wait for the analytics to stop. + assertBusy(() -> assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)), 30, TimeUnit.SECONDS); + + // Verify that the destination index got created. + assertTrue(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + } + + public void testStopDataFrameAnalyticsConfig() throws Exception { + String sourceIndex = "stop-test-source-index"; + String destIndex = "stop-test-dest-index"; + createIndex(sourceIndex, mappingForClassification()); + highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT); + + // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail. 
+ assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT)); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "stop-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex(sourceIndex) + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex(destIndex) + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + + AcknowledgedResponse startDataFrameAnalyticsResponse = execute( + new StartDataFrameAnalyticsRequest(configId), + machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync); + assertTrue(startDataFrameAnalyticsResponse.isAcknowledged()); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STARTED)); + + StopDataFrameAnalyticsResponse stopDataFrameAnalyticsResponse = execute( + new StopDataFrameAnalyticsRequest(configId), + machineLearningClient::stopDataFrameAnalytics, machineLearningClient::stopDataFrameAnalyticsAsync); + assertTrue(stopDataFrameAnalyticsResponse.isStopped()); + assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)); + } + + private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOException { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + GetDataFrameAnalyticsStatsResponse statsResponse = + machineLearningClient.getDataFrameAnalyticsStats(new GetDataFrameAnalyticsStatsRequest(configId), RequestOptions.DEFAULT); + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = 
statsResponse.getAnalyticsStats().get(0); + return stats.getState(); + } + + public void testDeleteDataFrameAnalyticsConfig() throws Exception { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + String configId = "delete-test-config"; + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder(configId) + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("delete-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("delete-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); + + createIndex("delete-test-source-index", defaultMappingForTest()); + + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(0)); + + execute( + new PutDataFrameAnalyticsRequest(config), + machineLearningClient::putDataFrameAnalytics, machineLearningClient::putDataFrameAnalyticsAsync); + + getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(1)); + + AcknowledgedResponse deleteDataFrameAnalyticsResponse = execute( + new DeleteDataFrameAnalyticsRequest(configId), + machineLearningClient::deleteDataFrameAnalytics, machineLearningClient::deleteDataFrameAnalyticsAsync); + assertTrue(deleteDataFrameAnalyticsResponse.isAcknowledged()); + + getDataFrameAnalyticsResponse = execute( + new GetDataFrameAnalyticsRequest(configId + "*"), + machineLearningClient::getDataFrameAnalytics, machineLearningClient::getDataFrameAnalyticsAsync); + assertThat(getDataFrameAnalyticsResponse.getAnalytics(), hasSize(0)); + } + + public void 
testDeleteDataFrameAnalyticsConfig_ConfigNotFound() { + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("config_that_does_not_exist"); + ElasticsearchStatusException exception = expectThrows(ElasticsearchStatusException.class, + () -> execute( + request, machineLearningClient::deleteDataFrameAnalytics, machineLearningClient::deleteDataFrameAnalyticsAsync)); + assertThat(exception.status().getStatus(), equalTo(404)); + } + + public void testEvaluateDataFrame() throws IOException { + String indexName = "evaluate-test-index"; + createIndex(indexName, mappingForClassification()); + BulkRequest bulk = new BulkRequest() + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(docForClassification(indexName, false, 0.1)) // #0 + .add(docForClassification(indexName, false, 0.2)) // #1 + .add(docForClassification(indexName, false, 0.3)) // #2 + .add(docForClassification(indexName, false, 0.4)) // #3 + .add(docForClassification(indexName, false, 0.7)) // #4 + .add(docForClassification(indexName, true, 0.2)) // #5 + .add(docForClassification(indexName, true, 0.3)) // #6 + .add(docForClassification(indexName, true, 0.4)) // #7 + .add(docForClassification(indexName, true, 0.8)) // #8 + .add(docForClassification(indexName, true, 0.9)); // #9 + highLevelClient().bulk(bulk, RequestOptions.DEFAULT); + + MachineLearningClient machineLearningClient = highLevelClient().machineLearning(); + EvaluateDataFrameRequest evaluateDataFrameRequest = + new EvaluateDataFrameRequest( + indexName, + new BinarySoftClassification( + actualField, + probabilityField, + PrecisionMetric.at(0.4, 0.5, 0.6), RecallMetric.at(0.5, 0.7), ConfusionMatrixMetric.at(0.5), AucRocMetric.withCurve())); + + EvaluateDataFrameResponse evaluateDataFrameResponse = + execute(evaluateDataFrameRequest, machineLearningClient::evaluateDataFrame, machineLearningClient::evaluateDataFrameAsync); + 
assertThat(evaluateDataFrameResponse.getEvaluationName(), equalTo(BinarySoftClassification.NAME)); + assertThat(evaluateDataFrameResponse.getMetrics().size(), equalTo(4)); + + PrecisionMetric.Result precisionResult = evaluateDataFrameResponse.getMetricByName(PrecisionMetric.NAME); + assertThat(precisionResult.getMetricName(), equalTo(PrecisionMetric.NAME)); + // Precision is 3/5=0.6 as there were 3 true examples (#7, #8, #9) among the 5 positive examples (#3, #4, #7, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.4"), closeTo(0.6, 1e-9)); + // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.5"), closeTo(0.666666666, 1e-9)); + // Precision is 2/3=0.(6) as there were 2 true examples (#8, #9) among the 3 positive examples (#4, #8, #9) + assertThat(precisionResult.getScoreByThreshold("0.6"), closeTo(0.666666666, 1e-9)); + assertNull(precisionResult.getScoreByThreshold("0.1")); + + RecallMetric.Result recallResult = evaluateDataFrameResponse.getMetricByName(RecallMetric.NAME); + assertThat(recallResult.getMetricName(), equalTo(RecallMetric.NAME)); + // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) + assertThat(recallResult.getScoreByThreshold("0.5"), closeTo(0.4, 1e-9)); + // Recall is 2/5=0.4 as there were 2 true positive examples (#8, #9) among the 5 true examples (#5, #6, #7, #8, #9) + assertThat(recallResult.getScoreByThreshold("0.7"), closeTo(0.4, 1e-9)); + assertNull(recallResult.getScoreByThreshold("0.1")); + + ConfusionMatrixMetric.Result confusionMatrixResult = evaluateDataFrameResponse.getMetricByName(ConfusionMatrixMetric.NAME); + assertThat(confusionMatrixResult.getMetricName(), equalTo(ConfusionMatrixMetric.NAME)); + ConfusionMatrixMetric.ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); + assertThat(confusionMatrix.getTruePositives(), 
equalTo(2L)); // docs #8 and #9 + assertThat(confusionMatrix.getFalsePositives(), equalTo(1L)); // doc #4 + assertThat(confusionMatrix.getTrueNegatives(), equalTo(4L)); // docs #0, #1, #2 and #3 + assertThat(confusionMatrix.getFalseNegatives(), equalTo(3L)); // docs #5, #6 and #7 + assertNull(confusionMatrixResult.getScoreByThreshold("0.1")); + + AucRocMetric.Result aucRocResult = evaluateDataFrameResponse.getMetricByName(AucRocMetric.NAME); + assertThat(aucRocResult.getMetricName(), equalTo(AucRocMetric.NAME)); + assertThat(aucRocResult.getScore(), closeTo(0.70025, 1e-9)); + assertNotNull(aucRocResult.getCurve()); + List curve = aucRocResult.getCurve(); + AucRocMetric.AucRocPoint curvePointAtThreshold0 = curve.stream().filter(p -> p.getThreshold() == 0.0).findFirst().get(); + assertThat(curvePointAtThreshold0.getTruePositiveRate(), equalTo(1.0)); + assertThat(curvePointAtThreshold0.getFalsePositiveRate(), equalTo(1.0)); + assertThat(curvePointAtThreshold0.getThreshold(), equalTo(0.0)); + AucRocMetric.AucRocPoint curvePointAtThreshold1 = curve.stream().filter(p -> p.getThreshold() == 1.0).findFirst().get(); + assertThat(curvePointAtThreshold1.getTruePositiveRate(), equalTo(0.0)); + assertThat(curvePointAtThreshold1.getFalsePositiveRate(), equalTo(0.0)); + assertThat(curvePointAtThreshold1.getThreshold(), equalTo(1.0)); + } + + private static XContentBuilder defaultMappingForTest() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("timestamp") + .field("type", "date") + .endObject() + .startObject("total") + .field("type", "long") + .endObject() + .endObject() + .endObject(); + } + + private static final String actualField = "label"; + private static final String probabilityField = "p"; + + private static XContentBuilder mappingForClassification() throws IOException { + return XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject(actualField) + .field("type", 
"keyword") + .endObject() + .startObject(probabilityField) + .field("type", "double") + .endObject() + .endObject() + .endObject(); + } + + private static IndexRequest docForClassification(String indexName, boolean isTrue, double p) { + return new IndexRequest() + .index(indexName) + .source(XContentType.JSON, actualField, Boolean.toString(isTrue), probabilityField, p); + } + + private void createIndex(String indexName, XContentBuilder mapping) throws IOException { + highLevelClient().indices().create(new CreateIndexRequest(indexName).mapping(mapping), RequestOptions.DEFAULT); + } + public void testPutFilter() throws Exception { String filterId = "filter-job-test"; MlFilter mlFilter = MlFilter.builder(filterId) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java index c565af7c37202..f5776e99fd0eb 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MlTestStateCleaner.java @@ -20,14 +20,18 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.client.ml.CloseJobRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteJobRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; import org.elasticsearch.client.ml.GetDatafeedResponse; import org.elasticsearch.client.ml.GetJobRequest; import org.elasticsearch.client.ml.GetJobResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.datafeed.DatafeedConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import 
org.elasticsearch.client.ml.job.config.Job; import java.io.IOException; @@ -48,6 +52,7 @@ public MlTestStateCleaner(Logger logger, MachineLearningClient mlClient) { public void clearMlMetadata() throws IOException { deleteAllDatafeeds(); deleteAllJobs(); + deleteAllDataFrameAnalytics(); } private void deleteAllDatafeeds() throws IOException { @@ -99,4 +104,12 @@ private void closeAllJobs() { throw new RuntimeException("Had to resort to force-closing jobs, something went wrong?", e1); } } + + private void deleteAllDataFrameAnalytics() throws IOException { + GetDataFrameAnalyticsResponse getDataFrameAnalyticsResponse = + mlClient.getDataFrameAnalytics(GetDataFrameAnalyticsRequest.getAllDataFrameAnalyticsRequest(), RequestOptions.DEFAULT); + for (DataFrameAnalyticsConfig config : getDataFrameAnalyticsResponse.getAnalytics()) { + mlClient.deleteDataFrameAnalytics(new DeleteDataFrameAnalyticsRequest(config.getId()), RequestOptions.DEFAULT); + } + } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 183bce91f83ed..26e5842019675 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -46,6 +46,8 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.core.MainRequest; import org.elasticsearch.client.core.MainResponse; +import org.elasticsearch.client.dataframe.transforms.SyncConfig; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; import org.elasticsearch.client.indexlifecycle.AllocateAction; import org.elasticsearch.client.indexlifecycle.DeleteAction; import org.elasticsearch.client.indexlifecycle.ForceMergeAction; @@ -56,6 +58,13 @@ import org.elasticsearch.client.indexlifecycle.SetPriorityAction; import 
org.elasticsearch.client.indexlifecycle.ShrinkAction; import org.elasticsearch.client.indexlifecycle.UnfollowAction; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -109,6 +118,7 @@ import static org.hamcrest.CoreMatchers.endsWith; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.Matchers.hasItems; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -664,7 +674,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(20, namedXContents.size()); + assertEquals(31, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -674,7 +684,7 @@ public void testProvidedNamedXContents() { categories.put(namedXContent.categoryClass, counter + 1); } } - assertEquals("Had: " + categories, 4, categories.size()); + assertEquals("Had: " + categories, 9, categories.size()); assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); 
assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); @@ -698,6 +708,16 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(ShrinkAction.NAME)); assertTrue(names.contains(FreezeAction.NAME)); assertTrue(names.contains(SetPriorityAction.NAME)); + assertEquals(Integer.valueOf(1), categories.get(DataFrameAnalysis.class)); + assertTrue(names.contains(OutlierDetection.NAME.getPreferredName())); + assertEquals(Integer.valueOf(1), categories.get(SyncConfig.class)); + assertTrue(names.contains(TimeSyncConfig.NAME)); + assertEquals(Integer.valueOf(1), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.Evaluation.class)); + assertThat(names, hasItems(BinarySoftClassification.NAME)); + assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.class)); + assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); + assertEquals(Integer.valueOf(4), categories.get(org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric.Result.class)); + assertThat(names, hasItems(AucRocMetric.NAME, PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME)); } public void testApiNamingConventions() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java index 0ba1406b54641..c37e8f8997185 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/GetDataFrameTransformResponseTests.java @@ -79,6 +79,9 @@ private static void toXContent(GetDataFrameTransformResponse response, XContentB @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new 
SearchModule(Settings.EMPTY, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java index 828f3ee1b9e40..d335e6a497a8e 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PreviewDataFrameTransformRequestTests.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Optional; import static org.elasticsearch.client.dataframe.transforms.SourceConfigTests.randomSourceConfig; @@ -55,7 +56,10 @@ protected boolean supportsUnknownFields() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } public void testValidate() { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java index c72027ee354cc..01e1db2cb3823 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/PutDataFrameTransformRequestTests.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Optional; import static org.hamcrest.Matchers.containsString; @@ -71,6 +72,9 @@ protected boolean supportsUnknownFields() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java index 9146bd39d1336..79b7e85098e04 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/DataFrameTransformConfigTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.client.dataframe.transforms; +import org.elasticsearch.client.dataframe.DataFrameNamedXContentProvider; import org.elasticsearch.Version; import org.elasticsearch.client.dataframe.transforms.pivot.PivotConfigTests; import org.elasticsearch.common.settings.Settings; @@ -30,6 +31,7 @@ import java.io.IOException; import java.time.Instant; import java.util.Collections; +import java.util.List; import java.util.function.Predicate; import static org.elasticsearch.client.dataframe.transforms.DestConfigTests.randomDestConfig; @@ -41,12 +43,17 @@ public static DataFrameTransformConfig randomDataFrameTransformConfig() { return new 
DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), randomDestConfig(), + randomBoolean() ? null : randomSyncConfig(), PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100), randomBoolean() ? null : Instant.now(), randomBoolean() ? null : Version.CURRENT.toString()); } + public static SyncConfig randomSyncConfig() { + return TimeSyncConfigTests.randomTimeSyncConfig(); + } + @Override protected DataFrameTransformConfig createTestInstance() { return randomDataFrameTransformConfig(); @@ -71,6 +78,9 @@ protected Predicate getRandomFieldsExcludeFilter() { @Override protected NamedXContentRegistry xContentRegistry() { SearchModule searchModule = new SearchModule(Settings.EMPTY, Collections.emptyList()); - return new NamedXContentRegistry(searchModule.getNamedXContents()); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); + + return new NamedXContentRegistry(namedXContents); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..dd2a17eb0260d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/TimeSyncConfigTests.java @@ -0,0 +1,49 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.dataframe.transforms; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class TimeSyncConfigTests extends AbstractXContentTestCase { + + public static TimeSyncConfig randomTimeSyncConfig() { + return new TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), new TimeValue(randomNonNegativeLong())); + } + + @Override + protected TimeSyncConfig createTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected TimeSyncConfig doParseInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..0c6a0350882a4 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/dataframe/transforms/hlrc/TimeSyncConfigTests.java @@ -0,0 +1,59 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.dataframe.transforms.hlrc; + +import org.elasticsearch.client.AbstractResponseTestCase; +import org.elasticsearch.client.dataframe.transforms.TimeSyncConfig; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; + +public class TimeSyncConfigTests + extends AbstractResponseTestCase { + + public static org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig randomTimeSyncConfig() { + return new org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), + new TimeValue(randomNonNegativeLong())); + } + + public static void assertHlrcEquals(org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig serverTestInstance, + TimeSyncConfig clientInstance) { + assertEquals(serverTestInstance.getField(), clientInstance.getField()); + assertEquals(serverTestInstance.getDelay(), clientInstance.getDelay()); + } + + @Override + protected org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig createServerTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected TimeSyncConfig doParseToClientInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser); + } + + @Override + protected void 
assertInstances(org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig serverTestInstance, + TimeSyncConfig clientInstance) { + assertHlrcEquals(serverTestInstance, clientInstance); + } + +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java index 4f94db604f147..b3fa85880b465 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/DataFrameTransformDocumentationIT.java @@ -433,6 +433,7 @@ public void testPreview() throws IOException, InterruptedException { .setQueryConfig(queryConfig) .build(), // <1> pivotConfig); // <2> + PreviewDataFrameTransformRequest request = new PreviewDataFrameTransformRequest(transformConfig); // <3> // end::preview-data-frame-transform-request diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index fe7d04a4e0a8d..f33926582e3ea 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -39,6 +39,7 @@ import org.elasticsearch.client.ml.DeleteCalendarEventRequest; import org.elasticsearch.client.ml.DeleteCalendarJobRequest; import org.elasticsearch.client.ml.DeleteCalendarRequest; +import org.elasticsearch.client.ml.DeleteDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.DeleteDatafeedRequest; import org.elasticsearch.client.ml.DeleteExpiredDataRequest; import org.elasticsearch.client.ml.DeleteExpiredDataResponse; @@ -47,6 +48,8 @@ 
import org.elasticsearch.client.ml.DeleteJobRequest; import org.elasticsearch.client.ml.DeleteJobResponse; import org.elasticsearch.client.ml.DeleteModelSnapshotRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameRequest; +import org.elasticsearch.client.ml.EvaluateDataFrameResponse; import org.elasticsearch.client.ml.FindFileStructureRequest; import org.elasticsearch.client.ml.FindFileStructureResponse; import org.elasticsearch.client.ml.FlushJobRequest; @@ -61,6 +64,10 @@ import org.elasticsearch.client.ml.GetCalendarsResponse; import org.elasticsearch.client.ml.GetCategoriesRequest; import org.elasticsearch.client.ml.GetCategoriesResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsResponse; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsRequest; +import org.elasticsearch.client.ml.GetDataFrameAnalyticsStatsResponse; import org.elasticsearch.client.ml.GetModelSnapshotsRequest; import org.elasticsearch.client.ml.GetModelSnapshotsResponse; import org.elasticsearch.client.ml.GetDatafeedRequest; @@ -92,6 +99,8 @@ import org.elasticsearch.client.ml.PutCalendarJobRequest; import org.elasticsearch.client.ml.PutCalendarRequest; import org.elasticsearch.client.ml.PutCalendarResponse; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.PutDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.PutDatafeedRequest; import org.elasticsearch.client.ml.PutDatafeedResponse; import org.elasticsearch.client.ml.PutFilterRequest; @@ -101,8 +110,11 @@ import org.elasticsearch.client.ml.RevertModelSnapshotRequest; import org.elasticsearch.client.ml.RevertModelSnapshotResponse; import org.elasticsearch.client.ml.SetUpgradeModeRequest; +import org.elasticsearch.client.ml.StartDataFrameAnalyticsRequest; import org.elasticsearch.client.ml.StartDatafeedRequest; import org.elasticsearch.client.ml.StartDatafeedResponse; +import 
org.elasticsearch.client.ml.StopDataFrameAnalyticsRequest; +import org.elasticsearch.client.ml.StopDataFrameAnalyticsResponse; import org.elasticsearch.client.ml.StopDatafeedRequest; import org.elasticsearch.client.ml.StopDatafeedResponse; import org.elasticsearch.client.ml.UpdateDatafeedRequest; @@ -118,6 +130,21 @@ import org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsStats; +import org.elasticsearch.client.ml.dataframe.OutlierDetection; +import org.elasticsearch.client.ml.dataframe.QueryConfig; +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric.ConfusionMatrix; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; import org.elasticsearch.client.ml.filestructurefinder.FileStructure; import org.elasticsearch.client.ml.job.config.AnalysisConfig; import org.elasticsearch.client.ml.job.config.AnalysisLimits; @@ -139,13 +166,18 @@ import org.elasticsearch.client.ml.job.results.OverallBucket; import 
org.elasticsearch.client.ml.job.stats.JobStats; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.tasks.TaskId; +import org.hamcrest.CoreMatchers; import org.junit.After; import java.io.IOException; @@ -870,18 +902,7 @@ public void testPreviewDatafeed() throws Exception { client.machineLearning().putJob(new PutJobRequest(job), RequestOptions.DEFAULT); String datafeedId = job.getId() + "-feed"; String indexName = "preview_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId, job.getId()) .setIndices(indexName) .build(); @@ -938,18 +959,7 @@ public void testStartDatafeed() throws Exception { client.machineLearning().putJob(new PutJobRequest(job), RequestOptions.DEFAULT); String datafeedId = job.getId() + "-feed"; String indexName = "start_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - 
.startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId, job.getId()) .setIndices(indexName) .build(); @@ -1067,18 +1077,7 @@ public void testGetDatafeedStats() throws Exception { client.machineLearning().putJob(new PutJobRequest(secondJob), RequestOptions.DEFAULT); String datafeedId1 = job.getId() + "-feed"; String indexName = "datafeed_stats_data_2"; - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); - createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() - .startObject("properties") - .startObject("timestamp") - .field("type", "date") - .endObject() - .startObject("total") - .field("type", "long") - .endObject() - .endObject() - .endObject()); - highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + createIndex(indexName); DatafeedConfig datafeed = DatafeedConfig.builder(datafeedId1, job.getId()) .setIndices(indexName) .build(); @@ -2802,6 +2801,463 @@ public void onFailure(Exception e) { } } + public void testGetDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::get-data-frame-analytics-request + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::get-data-frame-analytics-request + + // tag::get-data-frame-analytics-execute + GetDataFrameAnalyticsResponse response = client.machineLearning().getDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::get-data-frame-analytics-execute + + // 
tag::get-data-frame-analytics-response + List configs = response.getAnalytics(); + // end::get-data-frame-analytics-response + + assertThat(configs.size(), equalTo(1)); + } + { + GetDataFrameAnalyticsRequest request = new GetDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::get-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(GetDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-data-frame-analytics-execute-async + client.machineLearning().getDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testGetDataFrameAnalyticsStats() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::get-data-frame-analytics-stats-request + GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config"); // <1> + // end::get-data-frame-analytics-stats-request + + // tag::get-data-frame-analytics-stats-execute + GetDataFrameAnalyticsStatsResponse response = + client.machineLearning().getDataFrameAnalyticsStats(request, RequestOptions.DEFAULT); + // end::get-data-frame-analytics-stats-execute + + // tag::get-data-frame-analytics-stats-response + List stats = response.getAnalyticsStats(); + // end::get-data-frame-analytics-stats-response + + assertThat(stats.size(), equalTo(1)); + } + { + 
GetDataFrameAnalyticsStatsRequest request = new GetDataFrameAnalyticsStatsRequest("my-analytics-config"); + + // tag::get-data-frame-analytics-stats-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(GetDataFrameAnalyticsStatsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::get-data-frame-analytics-stats-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::get-data-frame-analytics-stats-execute-async + client.machineLearning().getDataFrameAnalyticsStatsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::get-data-frame-analytics-stats-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testPutDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + { + // tag::put-data-frame-analytics-query-config + QueryConfig queryConfig = new QueryConfig(new MatchAllQueryBuilder()); + // end::put-data-frame-analytics-query-config + + // tag::put-data-frame-analytics-source-config + DataFrameAnalyticsSource sourceConfig = DataFrameAnalyticsSource.builder() // <1> + .setIndex("put-test-source-index") // <2> + .setQueryConfig(queryConfig) // <3> + .build(); + // end::put-data-frame-analytics-source-config + + // tag::put-data-frame-analytics-dest-config + DataFrameAnalyticsDest destConfig = DataFrameAnalyticsDest.builder() // <1> + .setIndex("put-test-dest-index") // <2> + .build(); + // end::put-data-frame-analytics-dest-config + + // tag::put-data-frame-analytics-analysis-default + DataFrameAnalysis outlierDetection = OutlierDetection.createDefault(); // <1> + // end::put-data-frame-analytics-analysis-default + + // tag::put-data-frame-analytics-analysis-customized + 
DataFrameAnalysis outlierDetectionCustomized = OutlierDetection.builder() // <1> + .setMethod(OutlierDetection.Method.DISTANCE_KNN) // <2> + .setNNeighbors(5) // <3> + .build(); + // end::put-data-frame-analytics-analysis-customized + + // tag::put-data-frame-analytics-analyzed-fields + FetchSourceContext analyzedFields = + new FetchSourceContext( + true, + new String[] { "included_field_1", "included_field_2" }, + new String[] { "excluded_field" }); + // end::put-data-frame-analytics-analyzed-fields + + // tag::put-data-frame-analytics-config + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.builder("my-analytics-config") // <1> + .setSource(sourceConfig) // <2> + .setDest(destConfig) // <3> + .setAnalysis(outlierDetection) // <4> + .setAnalyzedFields(analyzedFields) // <5> + .setModelMemoryLimit(new ByteSizeValue(5, ByteSizeUnit.MB)) // <6> + .build(); + // end::put-data-frame-analytics-config + + // tag::put-data-frame-analytics-request + PutDataFrameAnalyticsRequest request = new PutDataFrameAnalyticsRequest(config); // <1> + // end::put-data-frame-analytics-request + + // tag::put-data-frame-analytics-execute + PutDataFrameAnalyticsResponse response = client.machineLearning().putDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::put-data-frame-analytics-execute + + // tag::put-data-frame-analytics-response + DataFrameAnalyticsConfig createdConfig = response.getConfig(); + // end::put-data-frame-analytics-response + + assertThat(createdConfig.getId(), equalTo("my-analytics-config")); + } + { + PutDataFrameAnalyticsRequest request = new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG); + // tag::put-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(PutDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::put-data-frame-analytics-execute-listener + + // Replace the empty listener by a 
blocking listener in test + final CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::put-data-frame-analytics-execute-async + client.machineLearning().putDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::put-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testDeleteDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::delete-data-frame-analytics-request + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::delete-data-frame-analytics-request + + // tag::delete-data-frame-analytics-execute + AcknowledgedResponse response = client.machineLearning().deleteDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::delete-data-frame-analytics-execute + + // tag::delete-data-frame-analytics-response + boolean acknowledged = response.isAcknowledged(); + // end::delete-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + DeleteDataFrameAnalyticsRequest request = new DeleteDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::delete-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(AcknowledgedResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::delete-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + 
listener = new LatchedActionListener<>(listener, latch); + + // tag::delete-data-frame-analytics-execute-async + client.machineLearning().deleteDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::delete-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } + + public void testStartDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + highLevelClient().index( + new IndexRequest(DF_ANALYTICS_CONFIG.getSource().getIndex()).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::start-data-frame-analytics-request + StartDataFrameAnalyticsRequest request = new StartDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::start-data-frame-analytics-request + + // tag::start-data-frame-analytics-execute + AcknowledgedResponse response = client.machineLearning().startDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::start-data-frame-analytics-execute + + // tag::start-data-frame-analytics-response + boolean acknowledged = response.isAcknowledged(); + // end::start-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + { + StartDataFrameAnalyticsRequest request = new StartDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::start-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(AcknowledgedResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::start-data-frame-analytics-execute-listener + + // Replace the empty listener 
by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::start-data-frame-analytics-execute-async + client.machineLearning().startDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::start-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + } + + public void testStopDataFrameAnalytics() throws Exception { + createIndex(DF_ANALYTICS_CONFIG.getSource().getIndex()); + highLevelClient().index( + new IndexRequest(DF_ANALYTICS_CONFIG.getSource().getIndex()).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT); + RestHighLevelClient client = highLevelClient(); + client.machineLearning().putDataFrameAnalytics(new PutDataFrameAnalyticsRequest(DF_ANALYTICS_CONFIG), RequestOptions.DEFAULT); + { + // tag::stop-data-frame-analytics-request + StopDataFrameAnalyticsRequest request = new StopDataFrameAnalyticsRequest("my-analytics-config"); // <1> + // end::stop-data-frame-analytics-request + + // tag::stop-data-frame-analytics-execute + StopDataFrameAnalyticsResponse response = client.machineLearning().stopDataFrameAnalytics(request, RequestOptions.DEFAULT); + // end::stop-data-frame-analytics-execute + + // tag::stop-data-frame-analytics-response + boolean acknowledged = response.isStopped(); + // end::stop-data-frame-analytics-response + + assertThat(acknowledged, is(true)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + { + StopDataFrameAnalyticsRequest request = new StopDataFrameAnalyticsRequest("my-analytics-config"); + + // tag::stop-data-frame-analytics-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public 
void onResponse(StopDataFrameAnalyticsResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::stop-data-frame-analytics-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::stop-data-frame-analytics-execute-async + client.machineLearning().stopDataFrameAnalyticsAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::stop-data-frame-analytics-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + assertBusy( + () -> assertThat(getAnalyticsState(DF_ANALYTICS_CONFIG.getId()), equalTo(DataFrameAnalyticsState.STOPPED)), + 30, TimeUnit.SECONDS); + } + + public void testEvaluateDataFrame() throws Exception { + String indexName = "evaluate-test-index"; + CreateIndexRequest createIndexRequest = + new CreateIndexRequest(indexName) + .mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("label") + .field("type", "keyword") + .endObject() + .startObject("p") + .field("type", "double") + .endObject() + .endObject() + .endObject()); + BulkRequest bulkRequest = + new BulkRequest(indexName) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.1)) // #0 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.2)) // #1 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.3)) // #2 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.4)) // #3 + .add(new IndexRequest().source(XContentType.JSON, "label", false, "p", 0.7)) // #4 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.2)) // #5 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.3)) // #6 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.4)) // 
#7 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.8)) // #8 + .add(new IndexRequest().source(XContentType.JSON, "label", true, "p", 0.9)); // #9 + RestHighLevelClient client = highLevelClient(); + client.indices().create(createIndexRequest, RequestOptions.DEFAULT); + client.bulk(bulkRequest, RequestOptions.DEFAULT); + { + // tag::evaluate-data-frame-request + EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( // <1> + indexName, // <2> + new BinarySoftClassification( // <3> + "label", // <4> + "p", // <5> + // Evaluation metrics // <6> + PrecisionMetric.at(0.4, 0.5, 0.6), // <7> + RecallMetric.at(0.5, 0.7), // <8> + ConfusionMatrixMetric.at(0.5), // <9> + AucRocMetric.withCurve())); // <10> + // end::evaluate-data-frame-request + + // tag::evaluate-data-frame-execute + EvaluateDataFrameResponse response = client.machineLearning().evaluateDataFrame(request, RequestOptions.DEFAULT); + // end::evaluate-data-frame-execute + + // tag::evaluate-data-frame-response + List metrics = response.getMetrics(); // <1> + + PrecisionMetric.Result precisionResult = response.getMetricByName(PrecisionMetric.NAME); // <2> + double precision = precisionResult.getScoreByThreshold("0.4"); // <3> + + ConfusionMatrixMetric.Result confusionMatrixResult = response.getMetricByName(ConfusionMatrixMetric.NAME); // <4> + ConfusionMatrix confusionMatrix = confusionMatrixResult.getScoreByThreshold("0.5"); // <5> + // end::evaluate-data-frame-response + + assertThat( + metrics.stream().map(m -> m.getMetricName()).collect(Collectors.toList()), + containsInAnyOrder(PrecisionMetric.NAME, RecallMetric.NAME, ConfusionMatrixMetric.NAME, AucRocMetric.NAME)); + assertThat(precision, closeTo(0.6, 1e-9)); + assertThat(confusionMatrix.getTruePositives(), CoreMatchers.equalTo(2L)); // docs #8 and #9 + assertThat(confusionMatrix.getFalsePositives(), CoreMatchers.equalTo(1L)); // doc #4 + assertThat(confusionMatrix.getTrueNegatives(), CoreMatchers.equalTo(4L)); // docs 
#0, #1, #2 and #3 + assertThat(confusionMatrix.getFalseNegatives(), CoreMatchers.equalTo(3L)); // docs #5, #6 and #7 + } + { + EvaluateDataFrameRequest request = new EvaluateDataFrameRequest( + indexName, + new BinarySoftClassification( + "label", + "p", + PrecisionMetric.at(0.4, 0.5, 0.6), + RecallMetric.at(0.5, 0.7), + ConfusionMatrixMetric.at(0.5), + AucRocMetric.withCurve())); + + // tag::evaluate-data-frame-execute-listener + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(EvaluateDataFrameResponse response) { + // <1> + } + + @Override + public void onFailure(Exception e) { + // <2> + } + }; + // end::evaluate-data-frame-execute-listener + + // Replace the empty listener by a blocking listener in test + CountDownLatch latch = new CountDownLatch(1); + listener = new LatchedActionListener<>(listener, latch); + + // tag::evaluate-data-frame-execute-async + client.machineLearning().evaluateDataFrameAsync(request, RequestOptions.DEFAULT, listener); // <1> + // end::evaluate-data-frame-execute-async + + assertTrue(latch.await(30L, TimeUnit.SECONDS)); + } + } public void testCreateFilter() throws Exception { RestHighLevelClient client = highLevelClient(); @@ -3140,4 +3596,39 @@ private String createFilter(RestHighLevelClient client) throws IOException { assertThat(createdFilter.getId(), equalTo("my_safe_domains")); return createdFilter.getId(); } + + private void createIndex(String indexName) throws IOException { + CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); + createIndexRequest.mapping(XContentFactory.jsonBuilder().startObject() + .startObject("properties") + .startObject("timestamp") + .field("type", "date") + .endObject() + .startObject("total") + .field("type", "long") + .endObject() + .endObject() + .endObject()); + highLevelClient().indices().create(createIndexRequest, RequestOptions.DEFAULT); + } + + private DataFrameAnalyticsState getAnalyticsState(String configId) throws IOException { + 
GetDataFrameAnalyticsStatsResponse statsResponse = + highLevelClient().machineLearning().getDataFrameAnalyticsStats( + new GetDataFrameAnalyticsStatsRequest(configId), RequestOptions.DEFAULT); + assertThat(statsResponse.getAnalyticsStats(), hasSize(1)); + DataFrameAnalyticsStats stats = statsResponse.getAnalyticsStats().get(0); + return stats.getState(); + } + + private static final DataFrameAnalyticsConfig DF_ANALYTICS_CONFIG = + DataFrameAnalyticsConfig.builder("my-analytics-config") + .setSource(DataFrameAnalyticsSource.builder() + .setIndex("put-test-source-index") + .build()) + .setDest(DataFrameAnalyticsDest.builder() + .setIndex("put-test-dest-index") + .build()) + .setAnalysis(OutlierDetection.createDefault()) + .build(); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java new file mode 100644 index 0000000000000..825adcd2060f8 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricAucRocPointTests.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class AucRocMetricAucRocPointTests extends AbstractXContentTestCase { + + static AucRocMetric.AucRocPoint randomPoint() { + return new AucRocMetric.AucRocPoint(randomDouble(), randomDouble(), randomDouble()); + } + + @Override + protected AucRocMetric.AucRocPoint createTestInstance() { + return randomPoint(); + } + + @Override + protected AucRocMetric.AucRocPoint doParseInstance(XContentParser parser) throws IOException { + return AucRocMetric.AucRocPoint.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java new file mode 100644 index 0000000000000..9ea7689d60f32 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/AucRocMetricResultTests.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.AucRocMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.client.ml.AucRocMetricAucRocPointTests.randomPoint; + +public class AucRocMetricResultTests extends AbstractXContentTestCase { + + static AucRocMetric.Result randomResult() { + return new AucRocMetric.Result( + randomDouble(), + Stream + .generate(() -> randomPoint()) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList())); + } + + @Override + protected AucRocMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected AucRocMetric.Result doParseInstance(XContentParser parser) throws IOException { + return AucRocMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java new file mode 100644 index 0000000000000..28eb221b318c6 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricConfusionMatrixTests.java @@ -0,0 +1,47 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class ConfusionMatrixMetricConfusionMatrixTests extends AbstractXContentTestCase { + + static ConfusionMatrixMetric.ConfusionMatrix randomConfusionMatrix() { + return new ConfusionMatrixMetric.ConfusionMatrix(randomInt(), randomInt(), randomInt(), randomInt()); + } + + @Override + protected ConfusionMatrixMetric.ConfusionMatrix createTestInstance() { + return randomConfusionMatrix(); + } + + @Override + protected ConfusionMatrixMetric.ConfusionMatrix doParseInstance(XContentParser parser) throws IOException { + return ConfusionMatrixMetric.ConfusionMatrix.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java new file mode 100644 index 0000000000000..c4b299a96b536 --- /dev/null +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/ConfusionMatrixMetricResultTests.java @@ -0,0 +1,62 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.ConfusionMatrixMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.elasticsearch.client.ml.ConfusionMatrixMetricConfusionMatrixTests.randomConfusionMatrix; + +public class ConfusionMatrixMetricResultTests extends AbstractXContentTestCase { + + static ConfusionMatrixMetric.Result randomResult() { + return new ConfusionMatrixMetric.Result( + Stream + .generate(() -> randomConfusionMatrix()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected ConfusionMatrixMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected ConfusionMatrixMetric.Result doParseInstance(XContentParser parser) throws IOException { 
+ return ConfusionMatrixMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..bc2ca2d954e76 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/DeleteDataFrameAnalyticsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class DeleteDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new DeleteDataFrameAnalyticsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new DeleteDataFrameAnalyticsRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new DeleteDataFrameAnalyticsRequest(null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java new file mode 100644 index 0000000000000..b41d113686ccf --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/EvaluateDataFrameResponseTests.java @@ -0,0 +1,76 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric; +import org.elasticsearch.client.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Predicate; + +public class EvaluateDataFrameResponseTests extends AbstractXContentTestCase { + + public static EvaluateDataFrameResponse randomResponse() { + List metrics = new ArrayList<>(); + if (randomBoolean()) { + metrics.add(AucRocMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(PrecisionMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(RecallMetricResultTests.randomResult()); + } + if (randomBoolean()) { + metrics.add(ConfusionMatrixMetricResultTests.randomResult()); + } + return new EvaluateDataFrameResponse(randomAlphaOfLength(5), metrics); + } + + @Override + protected EvaluateDataFrameResponse createTestInstance() { + return randomResponse(); + } + + @Override + protected EvaluateDataFrameResponse doParseInstance(XContentParser parser) throws IOException { + return EvaluateDataFrameResponse.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // allow unknown fields in the metrics map (i.e. 
alongside named metrics like "precision" or "recall") + return field -> field.isEmpty() || field.contains("."); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..56d87ea6bef49 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class GetDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetDataFrameAnalyticsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetDataFrameAnalyticsRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new GetDataFrameAnalyticsRequest(new String[0]).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java new file mode 100644 index 0000000000000..4e08d99eaa659 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/GetDataFrameAnalyticsStatsRequestTests.java @@ -0,0 +1,39 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class GetDataFrameAnalyticsStatsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new GetDataFrameAnalyticsStatsRequest("valid-id").validate()); + assertEquals(Optional.empty(), new GetDataFrameAnalyticsStatsRequest("").validate()); + } + + public void testValidate_Failure() { + assertThat(new GetDataFrameAnalyticsStatsRequest(new String[0]).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java new file mode 100644 index 0000000000000..607adacebb827 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PrecisionMetricResultTests.java @@ -0,0 +1,60 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.PrecisionMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class PrecisionMetricResultTests extends AbstractXContentTestCase<PrecisionMetric.Result> { + + static PrecisionMetric.Result randomResult() { + return new PrecisionMetric.Result( + Stream + .generate(() -> randomDouble()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected PrecisionMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected PrecisionMetric.Result doParseInstance(XContentParser parser) throws IOException { + return PrecisionMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..7387ba8ddeb65 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/PutDataFrameAnalyticsRequestTests.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership.
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ValidationException; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.client.ml.dataframe.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class PutDataFrameAnalyticsRequestTests extends AbstractXContentTestCase<PutDataFrameAnalyticsRequest> { + + public void testValidate_Ok() { + assertFalse(createTestInstance().validate().isPresent()); + } + + public void testValidate_Failure() { + Optional<ValidationException> exception = new PutDataFrameAnalyticsRequest(null).validate(); + assertTrue(exception.isPresent()); + assertThat(exception.get().getMessage(), containsString("put requires a non-null data frame analytics config")); + } + + @Override + protected PutDataFrameAnalyticsRequest createTestInstance() { + return new
PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfigTests.randomDataFrameAnalyticsConfig()); + } + + @Override + protected PutDataFrameAnalyticsRequest doParseInstance(XContentParser parser) throws IOException { + return new PutDataFrameAnalyticsRequest(DataFrameAnalyticsConfig.fromXContent(parser)); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java new file mode 100644 index 0000000000000..138875007e30d --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/RecallMetricResultTests.java @@ -0,0 +1,60 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.RecallMetric; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class RecallMetricResultTests extends AbstractXContentTestCase<RecallMetric.Result> { + + static RecallMetric.Result randomResult() { + return new RecallMetric.Result( + Stream + .generate(() -> randomDouble()) + .limit(randomIntBetween(1, 5)) + .collect(Collectors.toMap(v -> String.valueOf(randomDouble()), v -> v))); + } + + @Override + protected RecallMetric.Result createTestInstance() { + return randomResult(); + } + + @Override + protected RecallMetric.Result doParseInstance(XContentParser parser) throws IOException { + return RecallMetric.Result.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // disallow unknown fields in the root of the object as field names must be parsable as numbers + return field -> field.isEmpty(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..6e43b50bcd12b --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StartDataFrameAnalyticsRequestTests.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership.
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class StartDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").setTimeout(null).validate()); + assertEquals(Optional.empty(), new StartDataFrameAnalyticsRequest("foo").setTimeout(TimeValue.ZERO).validate()); + } + + public void testValidate_Failure() { + assertThat(new StartDataFrameAnalyticsRequest(null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + assertThat(new StartDataFrameAnalyticsRequest(null).setTimeout(TimeValue.ZERO).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..57af2083743ae --- /dev/null +++ 
b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsRequestTests.java @@ -0,0 +1,43 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.ESTestCase; + +import java.util.Optional; + +import static org.hamcrest.Matchers.containsString; + +public class StopDataFrameAnalyticsRequestTests extends ESTestCase { + + public void testValidate_Ok() { + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").validate()); + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").setTimeout(null).validate()); + assertEquals(Optional.empty(), new StopDataFrameAnalyticsRequest("foo").setTimeout(TimeValue.ZERO).validate()); + } + + public void testValidate_Failure() { + assertThat(new StopDataFrameAnalyticsRequest(null).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + assertThat(new StopDataFrameAnalyticsRequest(null).setTimeout(TimeValue.ZERO).validate().get().getMessage(), + containsString("data frame analytics id must not be null")); + } +} diff --git 
a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java new file mode 100644 index 0000000000000..55ef1aed7534a --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/StopDataFrameAnalyticsResponseTests.java @@ -0,0 +1,42 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.client.ml; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class StopDataFrameAnalyticsResponseTests extends AbstractXContentTestCase<StopDataFrameAnalyticsResponse> { + + @Override + protected StopDataFrameAnalyticsResponse createTestInstance() { + return new StopDataFrameAnalyticsResponse(randomBoolean()); + } + + @Override + protected StopDataFrameAnalyticsResponse doParseInstance(XContentParser parser) throws IOException { + return StopDataFrameAnalyticsResponse.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java new file mode 100644 index 0000000000000..f6826af551d0a --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -0,0 +1,88 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.Predicate; + +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsSourceTests.randomSourceConfig; +import static org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDestTests.randomDestConfig; +import static org.elasticsearch.client.ml.dataframe.OutlierDetectionTests.randomOutlierDetection; + +public class DataFrameAnalyticsConfigTests extends AbstractXContentTestCase<DataFrameAnalyticsConfig> { + + public static DataFrameAnalyticsConfig randomDataFrameAnalyticsConfig() { + DataFrameAnalyticsConfig.Builder builder = + DataFrameAnalyticsConfig.builder(randomAlphaOfLengthBetween(1, 10)) + .setSource(randomSourceConfig()) + .setDest(randomDestConfig()) + .setAnalysis(randomOutlierDetection()); + if (randomBoolean()) { + builder.setAnalyzedFields(new FetchSourceContext(true, + generateRandomStringArray(10, 10, false, false), + generateRandomStringArray(10, 10, false, false))); + } + if (randomBoolean()) { + builder.setModelMemoryLimit(new ByteSizeValue(randomIntBetween(1, 16), randomFrom(ByteSizeUnit.MB, ByteSizeUnit.GB))); + } + return builder.build(); + } + + @Override + protected DataFrameAnalyticsConfig createTestInstance() { + return randomDataFrameAnalyticsConfig(); + } + + @Override + protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException { + return
DataFrameAnalyticsConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only + return field -> !field.isEmpty(); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>(); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + return new NamedXContentRegistry(namedXContent); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java new file mode 100644 index 0000000000000..dce7ca5204d57 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -0,0 +1,50 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +public class DataFrameAnalyticsDestTests extends AbstractXContentTestCase<DataFrameAnalyticsDest> { + + public static DataFrameAnalyticsDest randomDestConfig() { + return DataFrameAnalyticsDest.builder() + .setIndex(randomAlphaOfLengthBetween(1, 10)) + .setResultsField(randomBoolean() ? null : randomAlphaOfLengthBetween(1, 10)) + .build(); + } + + @Override + protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsDest.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected DataFrameAnalyticsDest createTestInstance() { + return randomDestConfig(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java new file mode 100644 index 0000000000000..246cd67c1baf1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsSourceTests.java @@ -0,0 +1,70 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; + +import static java.util.Collections.emptyList; +import static org.elasticsearch.client.ml.dataframe.QueryConfigTests.randomQueryConfig; + + +public class DataFrameAnalyticsSourceTests extends AbstractXContentTestCase<DataFrameAnalyticsSource> { + + public static DataFrameAnalyticsSource randomSourceConfig() { + return DataFrameAnalyticsSource.builder() + .setIndex(randomAlphaOfLengthBetween(1, 10)) + .setQueryConfig(randomBoolean() ?
null : randomQueryConfig()) + .build(); + } + + @Override + protected DataFrameAnalyticsSource doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsSource.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate<String> getRandomFieldsExcludeFilter() { + // allow unknown fields in the root of the object only as QueryConfig stores a Map + return field -> !field.isEmpty(); + } + + @Override + protected DataFrameAnalyticsSource createTestInstance() { + return randomSourceConfig(); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java new file mode 100644 index 0000000000000..ed6e24f754d19 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java @@ -0,0 +1,66 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.client.ml.NodeAttributesTests; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester; + +public class DataFrameAnalyticsStatsTests extends ESTestCase { + + public void testFromXContent() throws IOException { + xContentTester(this::createParser, + DataFrameAnalyticsStatsTests::randomDataFrameAnalyticsStats, + DataFrameAnalyticsStatsTests::toXContent, + DataFrameAnalyticsStats::fromXContent) + .supportsUnknownFields(true) + .randomFieldsExcludeFilter(field -> field.startsWith("node.attributes")) + .test(); + } + + public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() { + return new DataFrameAnalyticsStats( + randomAlphaOfLengthBetween(1, 10), + randomFrom(DataFrameAnalyticsState.values()), + randomBoolean() ? null : randomIntBetween(0, 100), + randomBoolean() ? null : NodeAttributesTests.createRandom(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 20)); + } + + public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder builder) throws IOException { + builder.startObject(); + builder.field(DataFrameAnalyticsStats.ID.getPreferredName(), stats.getId()); + builder.field(DataFrameAnalyticsStats.STATE.getPreferredName(), stats.getState().value()); + if (stats.getProgressPercent() != null) { + builder.field(DataFrameAnalyticsStats.PROGRESS_PERCENT.getPreferredName(), stats.getProgressPercent()); + } + if (stats.getNode() != null) { + builder.field(DataFrameAnalyticsStats.NODE.getPreferredName(), stats.getNode()); + } + if (stats.getAssignmentExplanation() != null) { + builder.field(DataFrameAnalyticsStats.ASSIGNMENT_EXPLANATION.getPreferredName(), stats.getAssignmentExplanation()); + } + builder.endObject(); + } +} diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java new file mode 100644 index 0000000000000..de110d92fdee1 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/OutlierDetectionTests.java @@ -0,0 +1,73 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class OutlierDetectionTests extends AbstractXContentTestCase<OutlierDetection> { + + public static OutlierDetection randomOutlierDetection() { + return OutlierDetection.builder() + .setNNeighbors(randomBoolean() ? null : randomIntBetween(1, 20)) + .setMethod(randomBoolean() ? null : randomFrom(OutlierDetection.Method.values())) + .setMinScoreToWriteFeatureInfluence(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, true)) + .build(); + } + + @Override + protected OutlierDetection doParseInstance(XContentParser parser) throws IOException { + return OutlierDetection.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected OutlierDetection createTestInstance() { + return randomOutlierDetection(); + } + + public void testGetParams_GivenDefaults() { + OutlierDetection outlierDetection = OutlierDetection.createDefault(); + assertNull(outlierDetection.getNNeighbors()); + assertNull(outlierDetection.getMethod()); + assertNull(outlierDetection.getMinScoreToWriteFeatureInfluence()); + } + + public void testGetParams_GivenExplicitValues() { + OutlierDetection outlierDetection = + OutlierDetection.builder() + .setNNeighbors(42) + .setMethod(OutlierDetection.Method.LDOF) + .setMinScoreToWriteFeatureInfluence(0.5) + .build(); + assertThat(outlierDetection.getNNeighbors(), equalTo(42)); + assertThat(outlierDetection.getMethod(), equalTo(OutlierDetection.Method.LDOF)); + assertThat(outlierDetection.getMinScoreToWriteFeatureInfluence(), closeTo(0.5, 1E-9)); + } +} diff --git
a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java new file mode 100644 index 0000000000000..7413bc936a215 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/QueryConfigTests.java @@ -0,0 +1,62 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.client.ml.dataframe; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; + +import static java.util.Collections.emptyList; + +public class QueryConfigTests extends AbstractXContentTestCase<QueryConfig> { + + public static QueryConfig randomQueryConfig() { + QueryBuilder queryBuilder = randomBoolean() ?
new MatchAllQueryBuilder() : new MatchNoneQueryBuilder(); + return new QueryConfig(queryBuilder); + } + + @Override + protected QueryConfig createTestInstance() { + return randomQueryConfig(); + } + + @Override + protected QueryConfig doParseInstance(XContentParser parser) throws IOException { + return QueryConfig.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + SearchModule searchModule = new SearchModule(Settings.EMPTY, emptyList()); + return new NamedXContentRegistry(searchModule.getNamedXContents()); + } +} diff --git a/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..2e5ade37107cf --- /dev/null +++ b/docs/java-rest/high-level/ml/delete-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: delete-data-frame-analytics +:request: DeleteDataFrameAnalyticsRequest +:response: AcknowledgedResponse +-- +[id="{upid}-{api}"] +=== Delete Data Frame Analytics API + +The Delete Data Frame Analytics API is used to delete an existing {dataframe-analytics-config}. +The API accepts a +{request}+ object as a request and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Delete Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-analytics-config} deletion. 
diff --git a/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc new file mode 100644 index 0000000000000..660603d2e38e7 --- /dev/null +++ b/docs/java-rest/high-level/ml/evaluate-data-frame.asciidoc @@ -0,0 +1,45 @@ +-- +:api: evaluate-data-frame +:request: EvaluateDataFrameRequest +:response: EvaluateDataFrameResponse +-- +[id="{upid}-{api}"] +=== Evaluate Data Frame API + +The Evaluate Data Frame API is used to evaluate an ML algorithm that ran on a {dataframe}. +The API accepts an +{request}+ object and returns an +{response}+. + +[id="{upid}-{api}-request"] +==== Evaluate Data Frame Request + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new evaluation request +<2> Reference to an existing index +<3> Kind of evaluation to perform +<4> Name of the field in the index. Its value denotes the actual (i.e. ground truth) label for an example. Must be either true or false. +<5> Name of the field in the index. Its value denotes the probability (as per some ML algorithm) of the example being classified as positive. +<6> The remaining parameters are the metrics to be calculated based on the two fields described above. +<7> https://en.wikipedia.org/wiki/Precision_and_recall[Precision] calculated at thresholds: 0.4, 0.5 and 0.6 +<8> https://en.wikipedia.org/wiki/Precision_and_recall[Recall] calculated at thresholds: 0.5 and 0.7 +<9> https://en.wikipedia.org/wiki/Confusion_matrix[Confusion matrix] calculated at threshold 0.5 +<10> https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve[AUC ROC] calculated and the curve points returned + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested evaluation metrics. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- +<1> Fetching all the calculated metrics results +<2> Fetching precision metric by name +<3> Fetching precision at a given (0.4) threshold +<4> Fetching confusion matrix metric by name +<5> Fetching confusion matrix at a given (0.5) threshold \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc b/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc new file mode 100644 index 0000000000000..e1047e9b3e002 --- /dev/null +++ b/docs/java-rest/high-level/ml/get-data-frame-analytics-stats.asciidoc @@ -0,0 +1,34 @@ +-- +:api: get-data-frame-analytics-stats +:request: GetDataFrameAnalyticsStatsRequest +:response: GetDataFrameAnalyticsStatsResponse +-- +[id="{upid}-{api}"] +=== Get Data Frame Analytics Stats API + +The Get Data Frame Analytics Stats API is used to read the operational statistics of one or more {dataframe-analytics-config}s. +The API accepts a +{request}+ object and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Get Data Frame Analytics Stats Request + +A +{request}+ requires either a {dataframe-analytics-config} id, a comma separated list of ids or +the special wildcard `_all` to get the statistics for all {dataframe-analytics-config}s. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new GET Stats request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested {dataframe-analytics-config} statistics. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..c6d368efbcae9 --- /dev/null +++ b/docs/java-rest/high-level/ml/get-data-frame-analytics.asciidoc @@ -0,0 +1,34 @@ +-- +:api: get-data-frame-analytics +:request: GetDataFrameAnalyticsRequest +:response: GetDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Get Data Frame Analytics API + +The Get Data Frame Analytics API is used to get one or more {dataframe-analytics-config}s. +The API accepts a +{request}+ object and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Get Data Frame Analytics Request + +A +{request}+ requires either a {dataframe-analytics-config} id, a comma separated list of ids or +the special wildcard `_all` to get all {dataframe-analytics-config}s. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> Constructing a new GET request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the requested {dataframe-analytics-config}s. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..05fbd5bc3922a --- /dev/null +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -0,0 +1,115 @@ +-- +:api: put-data-frame-analytics +:request: PutDataFrameAnalyticsRequest +:response: PutDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Put Data Frame Analytics API + +The Put Data Frame Analytics API is used to create a new {dataframe-analytics-config}. +The API accepts a +{request}+ object as a request and returns a +{response}+. + +[id="{upid}-{api}-request"] +==== Put Data Frame Analytics Request + +A +{request}+ requires the following argument: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +-------------------------------------------------- +<1> The configuration of the {dataframe-job} to create + +[id="{upid}-{api}-config"] +==== Data Frame Analytics Configuration + +The `DataFrameAnalyticsConfig` object contains all the details about the {dataframe-job} +configuration and contains the following arguments: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-config] +-------------------------------------------------- +<1> The {dataframe-analytics-config} id +<2> The source index and query from which to gather data +<3> The destination index +<4> The analysis to be performed +<5> The fields to be included in / excluded from the analysis +<6> The memory limit for the model created as part of the analysis process + +[id="{upid}-{api}-query-config"] + 
+==== SourceConfig + +The index and the query from which to collect data. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-source-config] +-------------------------------------------------- +<1> Constructing a new DataFrameAnalyticsSource +<2> The source index +<3> The query from which to gather the data. If query is not set, a `match_all` query is used by default. + +===== QueryConfig + +The query with which to select data from the source. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-query-config] +-------------------------------------------------- + +==== DestinationConfig + +The index to which data should be written by the {dataframe-job}. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-dest-config] +-------------------------------------------------- +<1> Constructing a new DataFrameAnalyticsDest +<2> The destination index + +==== Analysis + +The analysis to be performed. +Currently, only one analysis is supported: +OutlierDetection+. 
+ ++OutlierDetection+ analysis can be created in one of two ways: + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analysis-default] +-------------------------------------------------- +<1> Constructing a new OutlierDetection object with default strategy to determine outliers + +or +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analysis-customized] +-------------------------------------------------- +<1> Constructing a new OutlierDetection object +<2> The method used to perform the analysis +<3> Number of neighbors taken into account during analysis + +==== Analyzed fields + +FetchContext object containing fields to be included in / excluded from the analysis. + +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-analyzed-fields] +-------------------------------------------------- + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ contains the newly created {dataframe-analytics-config}. 
+ +["source","java",subs="attributes,callouts,macros"] +-------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-response] +-------------------------------------------------- \ No newline at end of file diff --git a/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..610607daba1f8 --- /dev/null +++ b/docs/java-rest/high-level/ml/start-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: start-data-frame-analytics +:request: StartDataFrameAnalyticsRequest +:response: AcknowledgedResponse +-- +[id="{upid}-{api}"] +=== Start Data Frame Analytics API + +The Start Data Frame Analytics API is used to start an existing {dataframe-analytics-config}. +It accepts a +{request}+ object and responds with a +{response}+ object. + +[id="{upid}-{api}-request"] +==== Start Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new start request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-job} has started. 
\ No newline at end of file diff --git a/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc new file mode 100644 index 0000000000000..243c075e18b03 --- /dev/null +++ b/docs/java-rest/high-level/ml/stop-data-frame-analytics.asciidoc @@ -0,0 +1,28 @@ +-- +:api: stop-data-frame-analytics +:request: StopDataFrameAnalyticsRequest +:response: StopDataFrameAnalyticsResponse +-- +[id="{upid}-{api}"] +=== Stop Data Frame Analytics API + +The Stop Data Frame Analytics API is used to stop a running {dataframe-analytics-config}. +It accepts a +{request}+ object and responds with a +{response}+ object. + +[id="{upid}-{api}-request"] +==== Stop Data Frame Analytics Request + +A +{request}+ object requires a {dataframe-analytics-config} id. + +["source","java",subs="attributes,callouts,macros"] +--------------------------------------------------- +include-tagged::{doc-tests-file}[{api}-request] +--------------------------------------------------- +<1> Constructing a new stop request referencing an existing {dataframe-analytics-config} + +include::../execution.asciidoc[] + +[id="{upid}-{api}-response"] +==== Response + +The returned +{response}+ object acknowledges the {dataframe-job} has stopped. 
\ No newline at end of file diff --git a/docs/java-rest/high-level/supported-apis.asciidoc b/docs/java-rest/high-level/supported-apis.asciidoc index 4e28efc2941db..21ebdfab65155 100644 --- a/docs/java-rest/high-level/supported-apis.asciidoc +++ b/docs/java-rest/high-level/supported-apis.asciidoc @@ -285,6 +285,13 @@ The Java High Level REST Client supports the following Machine Learning APIs: * <<{upid}-put-calendar-job>> * <<{upid}-delete-calendar-job>> * <<{upid}-delete-calendar>> +* <<{upid}-get-data-frame-analytics>> +* <<{upid}-get-data-frame-analytics-stats>> +* <<{upid}-put-data-frame-analytics>> +* <<{upid}-delete-data-frame-analytics>> +* <<{upid}-start-data-frame-analytics>> +* <<{upid}-stop-data-frame-analytics>> +* <<{upid}-evaluate-data-frame>> * <<{upid}-put-filter>> * <<{upid}-get-filters>> * <<{upid}-update-filter>> @@ -329,6 +336,13 @@ include::ml/delete-calendar-event.asciidoc[] include::ml/put-calendar-job.asciidoc[] include::ml/delete-calendar-job.asciidoc[] include::ml/delete-calendar.asciidoc[] +include::ml/get-data-frame-analytics.asciidoc[] +include::ml/get-data-frame-analytics-stats.asciidoc[] +include::ml/put-data-frame-analytics.asciidoc[] +include::ml/delete-data-frame-analytics.asciidoc[] +include::ml/start-data-frame-analytics.asciidoc[] +include::ml/stop-data-frame-analytics.asciidoc[] +include::ml/evaluate-data-frame.asciidoc[] include::ml/put-filter.asciidoc[] include::ml/get-filters.asciidoc[] include::ml/update-filter.asciidoc[] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index eb05b3f013b81..e2037d4e260a9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -48,6 +48,8 @@ import org.elasticsearch.xpack.core.dataframe.action.StopDataFrameTransformAction; 
import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransform; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformState; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.elasticsearch.xpack.core.deprecation.DeprecationInfoAction; import org.elasticsearch.xpack.core.graph.GraphFeatureSetUsage; import org.elasticsearch.xpack.core.graph.action.GraphExploreAction; @@ -79,12 +81,14 @@ import org.elasticsearch.xpack.core.ml.action.CloseJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction; import org.elasticsearch.xpack.core.ml.action.DeleteCalendarEventAction; +import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.DeleteDatafeedAction; import org.elasticsearch.xpack.core.ml.action.DeleteExpiredDataAction; import org.elasticsearch.xpack.core.ml.action.DeleteFilterAction; import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction; import org.elasticsearch.xpack.core.ml.action.DeleteJobAction; import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; import org.elasticsearch.xpack.core.ml.action.FindFileStructureAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; @@ -93,6 +97,8 @@ import org.elasticsearch.xpack.core.ml.action.GetCalendarEventsAction; import org.elasticsearch.xpack.core.ml.action.GetCalendarsAction; import org.elasticsearch.xpack.core.ml.action.GetCategoriesAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetDatafeedsAction; import 
org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction; import org.elasticsearch.xpack.core.ml.action.GetFiltersAction; @@ -111,11 +117,13 @@ import org.elasticsearch.xpack.core.ml.action.PostDataAction; import org.elasticsearch.xpack.core.ml.action.PreviewDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutCalendarAction; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.PutDatafeedAction; import org.elasticsearch.xpack.core.ml.action.PutFilterAction; import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.RevertModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.SetUpgradeModeAction; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; import org.elasticsearch.xpack.core.ml.action.UpdateCalendarJobAction; @@ -127,6 +135,18 @@ import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; +import 
org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; import org.elasticsearch.xpack.core.monitoring.MonitoringFeatureSetUsage; import org.elasticsearch.xpack.core.rollup.RollupFeatureSetUsage; @@ -282,6 +302,12 @@ public List> getClientActions() { PersistJobAction.INSTANCE, FindFileStructureAction.INSTANCE, SetUpgradeModeAction.INSTANCE, + PutDataFrameAnalyticsAction.INSTANCE, + GetDataFrameAnalyticsAction.INSTANCE, + GetDataFrameAnalyticsStatsAction.INSTANCE, + DeleteDataFrameAnalyticsAction.INSTANCE, + StartDataFrameAnalyticsAction.INSTANCE, + EvaluateDataFrameAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, @@ -374,11 +400,30 @@ public List getNamedWriteables() { StartDatafeedAction.DatafeedParams::new), new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.JOB_TASK_NAME, OpenJobAction.JobParams::new), + new NamedWriteableRegistry.Entry(PersistentTaskParams.class, MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, + StartDataFrameAnalyticsAction.TaskParams::new), // ML - Task states new NamedWriteableRegistry.Entry(PersistentTaskState.class, JobTaskState.NAME, JobTaskState::new), new NamedWriteableRegistry.Entry(PersistentTaskState.class, DatafeedState.NAME, DatafeedState::fromStream), + new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameAnalyticsTaskState.NAME, + DataFrameAnalyticsTaskState::new), new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MACHINE_LEARNING, MachineLearningFeatureSetUsage::new), + // ML - Data frame analytics + new 
NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), OutlierDetection::new), + // ML - Data frame evaluation + new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), + BinarySoftClassification::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), AucRoc::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), Precision::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), Recall::new), + new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), AucRoc.Result::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, ScoreByThresholdResult::new), + new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix.Result::new), + // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), // security @@ -439,6 +484,7 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(PersistentTaskParams.class, DataFrameField.TASK_NAME, DataFrameTransform::new), new NamedWriteableRegistry.Entry(Task.Status.class, DataFrameField.TASK_NAME, DataFrameTransformState::new), new NamedWriteableRegistry.Entry(PersistentTaskState.class, DataFrameField.TASK_NAME, DataFrameTransformState::new), + new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), TimeSyncConfig::new), // Vectors new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.VECTORS, VectorsFeatureSetUsage::new) ); @@ -455,9 +501,13 @@ public List getNamedXContent() { 
StartDatafeedAction.DatafeedParams::fromXContent), new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.JOB_TASK_NAME), OpenJobAction.JobParams::fromXContent), + new NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME), + StartDataFrameAnalyticsAction.TaskParams::fromXContent), // ML - Task states new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(DatafeedState.NAME), DatafeedState::fromXContent), new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(JobTaskState.NAME), JobTaskState::fromXContent), + new NamedXContentRegistry.Entry(PersistentTaskState.class, new ParseField(DataFrameAnalyticsTaskState.NAME), + DataFrameAnalyticsTaskState::fromXContent), // watcher new NamedXContentRegistry.Entry(MetaData.Custom.class, new ParseField(WatcherMetaData.TYPE), WatcherMetaData::fromXContent), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java index c61ed2ddde8be..71878c4894d6a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameField.java @@ -28,6 +28,10 @@ public final class DataFrameField { public static final ParseField DESTINATION = new ParseField("dest"); public static final ParseField FORCE = new ParseField("force"); public static final ParseField MAX_PAGE_SEARCH_SIZE = new ParseField("max_page_search_size"); + public static final ParseField FIELD = new ParseField("field"); + public static final ParseField SYNC = new ParseField("sync"); + public static final ParseField TIME_BASED_SYNC = new ParseField("time"); + public static final ParseField DELAY = new ParseField("delay"); /** * Fields for checkpointing diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java index e6e6ac860e37c..7fe51feb2260a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameMessages.java @@ -38,7 +38,9 @@ public class DataFrameMessages { public static final String FAILED_TO_PARSE_TRANSFORM_CONFIGURATION = "Failed to parse transform configuration for data frame transform [{0}]"; public static final String FAILED_TO_PARSE_TRANSFORM_STATISTICS_CONFIGURATION = - "Failed to parse transform statistics for data frame transform [{0}]"; + "Failed to parse transform statistics for data frame transform [{0}]"; + public static final String FAILED_TO_LOAD_TRANSFORM_CHECKPOINT = + "Failed to load data frame transform checkpoint for transform [{0}]"; public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_NO_TRANSFORM = "Data frame transform configuration must specify exactly 1 function"; public static final String DATA_FRAME_TRANSFORM_CONFIGURATION_PIVOT_NO_GROUP_BY = diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java new file mode 100644 index 0000000000000..9eacfc5ff1eae --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/DataFrameNamedXContentProvider.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.dataframe; + +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; + +import java.util.Arrays; +import java.util.List; + +public class DataFrameNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + return Arrays.asList( + new NamedXContentRegistry.Entry(SyncConfig.class, + DataFrameField.TIME_BASED_SYNC, + TimeSyncConfig::parse)); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java index 19d4d6ab6eed1..2762e0507ef06 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfig.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.xpack.core.dataframe.DataFrameField; import org.elasticsearch.xpack.core.dataframe.DataFrameMessages; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfig; @@ -55,6 +56,7 @@ public class DataFrameTransformConfig extends AbstractDiffable create SourceConfig source = (SourceConfig) args[1]; DestConfig dest = (DestConfig) args[2]; - // ignored, only for internal storage: String docType = (String) args[3]; + SyncConfig syncConfig = (SyncConfig) args[3]; + // ignored, only for internal storage: String docType = 
(String) args[4]; // on strict parsing do not allow injection of headers, transform version, or create time if (lenient == false) { - validateStrictParsingParams(args[4], HEADERS.getPreferredName()); - validateStrictParsingParams(args[7], CREATE_TIME.getPreferredName()); - validateStrictParsingParams(args[8], VERSION.getPreferredName()); + validateStrictParsingParams(args[5], HEADERS.getPreferredName()); + validateStrictParsingParams(args[8], CREATE_TIME.getPreferredName()); + validateStrictParsingParams(args[9], VERSION.getPreferredName()); } @SuppressWarnings("unchecked") - Map headers = (Map) args[4]; + Map headers = (Map) args[5]; - PivotConfig pivotConfig = (PivotConfig) args[5]; - String description = (String)args[6]; + PivotConfig pivotConfig = (PivotConfig) args[6]; + String description = (String)args[7]; return new DataFrameTransformConfig(id, source, dest, + syncConfig, headers, pivotConfig, description, - (Instant)args[7], - (String)args[8]); + (Instant)args[8], + (String)args[9]); }); parser.declareString(optionalConstructorArg(), DataFrameField.ID); parser.declareObject(constructorArg(), (p, c) -> SourceConfig.fromXContent(p, lenient), DataFrameField.SOURCE); parser.declareObject(constructorArg(), (p, c) -> DestConfig.fromXContent(p, lenient), DataFrameField.DESTINATION); + parser.declareObject(optionalConstructorArg(), (p, c) -> parseSyncConfig(p, lenient), DataFrameField.SYNC); + parser.declareString(optionalConstructorArg(), DataFrameField.INDEX_DOC_TYPE); + parser.declareObject(optionalConstructorArg(), (p, c) -> p.mapStrings(), HEADERS); parser.declareObject(optionalConstructorArg(), (p, c) -> PivotConfig.fromXContent(p, lenient), PIVOT_TRANSFORM); parser.declareString(optionalConstructorArg(), DESCRIPTION); @@ -124,6 +131,14 @@ private static ConstructingObjectParser create return parser; } + private static SyncConfig parseSyncConfig(XContentParser parser, boolean ignoreUnknownFields) throws IOException { + 
XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + SyncConfig syncConfig = parser.namedObject(SyncConfig.class, parser.currentName(), ignoreUnknownFields); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return syncConfig; + } + public static String documentId(String transformId) { return NAME + "-" + transformId; } @@ -131,6 +146,7 @@ public static String documentId(String transformId) { DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final Map headers, final PivotConfig pivotConfig, final String description, @@ -139,6 +155,7 @@ public static String documentId(String transformId) { this.id = ExceptionsHelper.requireNonNull(id, DataFrameField.ID.getPreferredName()); this.source = ExceptionsHelper.requireNonNull(source, DataFrameField.SOURCE.getPreferredName()); this.dest = ExceptionsHelper.requireNonNull(dest, DataFrameField.DESTINATION.getPreferredName()); + this.syncConfig = syncConfig; this.setHeaders(headers == null ? 
Collections.emptyMap() : headers); this.pivotConfig = pivotConfig; this.description = description; @@ -157,10 +174,11 @@ public static String documentId(String transformId) { public DataFrameTransformConfig(final String id, final SourceConfig source, final DestConfig dest, + final SyncConfig syncConfig, final Map headers, final PivotConfig pivotConfig, final String description) { - this(id, source, dest, headers, pivotConfig, description, null, null); + this(id, source, dest, syncConfig, headers, pivotConfig, description, null, null); } public DataFrameTransformConfig(final StreamInput in) throws IOException { @@ -171,9 +189,11 @@ public DataFrameTransformConfig(final StreamInput in) throws IOException { pivotConfig = in.readOptionalWriteable(PivotConfig::new); description = in.readOptionalString(); if (in.getVersion().onOrAfter(Version.V_7_3_0)) { + syncConfig = in.readOptionalNamedWriteable(SyncConfig.class); createTime = in.readOptionalInstant(); transformVersion = in.readBoolean() ? 
Version.readVersion(in) : null; } else { + syncConfig = null; createTime = null; transformVersion = null; } @@ -191,6 +211,10 @@ public DestConfig getDestination() { return dest; } + public SyncConfig getSyncConfig() { + return syncConfig; + } + public Map getHeaders() { return headers; } @@ -233,6 +257,10 @@ public boolean isValid() { return false; } + if (syncConfig != null && syncConfig.isValid() == false) { + return false; + } + return source.isValid() && dest.isValid(); } @@ -245,8 +273,9 @@ public void writeTo(final StreamOutput out) throws IOException { out.writeOptionalWriteable(pivotConfig); out.writeOptionalString(description); if (out.getVersion().onOrAfter(Version.V_7_3_0)) { + out.writeOptionalNamedWriteable(syncConfig); out.writeOptionalInstant(createTime); - if (transformVersion != null) { + if (transformVersion != null) { out.writeBoolean(true); Version.writeVersion(transformVersion, out); } else { @@ -261,6 +290,11 @@ public XContentBuilder toXContent(final XContentBuilder builder, final Params pa builder.field(DataFrameField.ID.getPreferredName(), id); builder.field(DataFrameField.SOURCE.getPreferredName(), source); builder.field(DataFrameField.DESTINATION.getPreferredName(), dest); + if (syncConfig != null) { + builder.startObject(DataFrameField.SYNC.getPreferredName()); + builder.field(syncConfig.getWriteableName(), syncConfig); + builder.endObject(); + } if (pivotConfig != null) { builder.field(PIVOT_TRANSFORM.getPreferredName(), pivotConfig); } @@ -298,6 +332,7 @@ public boolean equals(Object other) { return Objects.equals(this.id, that.id) && Objects.equals(this.source, that.source) && Objects.equals(this.dest, that.dest) + && Objects.equals(this.syncConfig, that.syncConfig) && Objects.equals(this.headers, that.headers) && Objects.equals(this.pivotConfig, that.pivotConfig) && Objects.equals(this.description, that.description) @@ -307,7 +342,7 @@ public boolean equals(Object other) { @Override public int hashCode(){ - return Objects.hash(id, 
source, dest, headers, pivotConfig, description, createTime, transformVersion); + return Objects.hash(id, source, dest, syncConfig, headers, pivotConfig, description, createTime, transformVersion); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java new file mode 100644 index 0000000000000..19ff79ea7e0ee --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/SyncConfig.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.dataframe.transforms; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.index.query.QueryBuilder; + +public interface SyncConfig extends ToXContentObject, NamedWriteable { + + /** + * Validate configuration + * + * @return true if valid + */ + boolean isValid(); + + QueryBuilder getRangeQuery(DataFrameTransformCheckpoint newCheckpoint); + + QueryBuilder getRangeQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java new file mode 100644 index 0000000000000..0490394d90b26 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfig.java @@ -0,0 +1,148 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.dataframe.transforms; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.RangeQueryBuilder; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public class TimeSyncConfig implements SyncConfig { + + private static final String NAME = "data_frame_transform_pivot_sync_time"; + + private final String field; + private final TimeValue delay; + + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, lenient, + args -> { + String field = (String) args[0]; + TimeValue delay = args[1] != null ? 
(TimeValue) args[1] : TimeValue.ZERO; + + return new TimeSyncConfig(field, delay); + }); + + parser.declareString(constructorArg(), DataFrameField.FIELD); + parser.declareField(optionalConstructorArg(), + (p, c) -> TimeValue.parseTimeValue(p.textOrNull(), DataFrameField.DELAY.getPreferredName()), DataFrameField.DELAY, + ObjectParser.ValueType.STRING_OR_NULL); + + return parser; + } + + public TimeSyncConfig() { + this(null, null); + } + + public TimeSyncConfig(final String field, final TimeValue delay) { + this.field = ExceptionsHelper.requireNonNull(field, DataFrameField.FIELD.getPreferredName()); + this.delay = ExceptionsHelper.requireNonNull(delay, DataFrameField.DELAY.getPreferredName()); + } + + public TimeSyncConfig(StreamInput in) throws IOException { + this.field = in.readString(); + this.delay = in.readTimeValue(); + } + + public String getField() { + return field; + } + + public TimeValue getDelay() { + return delay; + } + + @Override + public boolean isValid() { + return true; + } + + @Override + public void writeTo(final StreamOutput out) throws IOException { + out.writeString(field); + out.writeTimeValue(delay); + } + + @Override + public XContentBuilder toXContent(final XContentBuilder builder, final Params params) throws IOException { + builder.startObject(); + builder.field(DataFrameField.FIELD.getPreferredName(), field); + if (delay.duration() > 0) { + builder.field(DataFrameField.DELAY.getPreferredName(), delay.getStringRep()); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + final TimeSyncConfig that = (TimeSyncConfig) other; + + return Objects.equals(this.field, that.field) + && Objects.equals(this.delay, that.delay); + } + + @Override + public int hashCode(){ + return Objects.hash(field, delay); + } + + @Override + public String toString() { + return 
Strings.toString(this, true, true); + } + + public static TimeSyncConfig parse(final XContentParser parser) { + return LENIENT_PARSER.apply(parser, null); + } + + public static TimeSyncConfig fromXContent(final XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + @Override + public String getWriteableName() { + return DataFrameField.TIME_BASED_SYNC.getPreferredName(); + } + + @Override + public QueryBuilder getRangeQuery(DataFrameTransformCheckpoint newCheckpoint) { + return new RangeQueryBuilder(field).lt(newCheckpoint.getTimeUpperBound()).format("epoch_millis"); + } + + @Override + public QueryBuilder getRangeQuery(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint) { + return new RangeQueryBuilder(field).gte(oldCheckpoint.getTimeUpperBound()).lt(newCheckpoint.getTimeUpperBound()) + .format("epoch_millis"); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java index a3861ef65f648..e38915c0beac6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/DateHistogramGroupSource.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder; import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramInterval; @@ -21,6 +22,7 @@ import java.time.ZoneId; import 
java.time.ZoneOffset; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -320,4 +322,15 @@ public boolean equals(Object other) { public int hashCode() { return Objects.hash(field, interval, timeZone, format); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + // no need for an extra range filter as this is already done by checkpoints + return null; + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java index 737590a0cc197..372f4ad99b608 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/HistogramGroupSource.java @@ -11,9 +11,11 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -99,4 +101,15 @@ public boolean equals(Object other) { public int hashCode() { return Objects.hash(field, interval); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + // histograms are simple and cheap, so we skip this optimization + return null; + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return false; + } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java index ab2f7d489ac9a..038299bfd8326 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/PivotConfig.java @@ -100,12 +100,16 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public void toCompositeAggXContent(XContentBuilder builder, Params params) throws IOException { + public void toCompositeAggXContent(XContentBuilder builder, boolean forChangeDetection) throws IOException { builder.startObject(); builder.field(CompositeAggregationBuilder.SOURCES_FIELD_NAME.getPreferredName()); builder.startArray(); for (Entry groupBy : groups.getGroups().entrySet()) { + // some group sources do not implement change detection, or it makes no sense for them; skip those + if (forChangeDetection && groupBy.getValue().supportsIncrementalBucketUpdate() == false) { + continue; + } builder.startObject(); builder.startObject(groupBy.getKey()); builder.field(groupBy.getValue().getType().value(), groupBy.getValue()); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java index 0a4cf2579460e..ff1f9c3d54ac8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/SingleGroupSource.java @@ -14,10 +14,12 @@ import org.elasticsearch.common.xcontent.AbstractObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import 
org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; import java.io.IOException; import java.util.Locale; import java.util.Objects; +import java.util.Set; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; @@ -94,6 +96,10 @@ public void writeTo(StreamOutput out) throws IOException { public abstract Type getType(); + public abstract boolean supportsIncrementalBucketUpdate(); + + public abstract QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets); + public String getField() { return field; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java index d4585a611b367..891b160da0762 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/dataframe/transforms/pivot/TermsGroupSource.java @@ -9,8 +9,11 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.TermsQueryBuilder; import java.io.IOException; +import java.util.Set; /* * A terms aggregation source for group_by @@ -47,4 +50,14 @@ public Type getType() { public static TermsGroupSource fromXContent(final XContentParser parser, boolean lenient) throws IOException { return lenient ? 
LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); } + + @Override + public QueryBuilder getIncrementalBucketUpdateFilterQuery(Set changedBuckets) { + return new TermsQueryBuilder(field, changedBuckets); + } + + @Override + public boolean supportsIncrementalBucketUpdate() { + return true; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java index 6b5ba086c6fe0..5c3da41df7349 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java @@ -13,7 +13,7 @@ public final class MachineLearningField { public static final Setting AUTODETECT_PROCESS = Setting.boolSetting("xpack.ml.autodetect_process", true, Setting.Property.NodeScope); public static final Setting MAX_MODEL_MEMORY_LIMIT = - Setting.memorySizeSetting("xpack.ml.max_model_memory_limit", new ByteSizeValue(0), + Setting.memorySizeSetting("xpack.ml.max_model_memory_limit", ByteSizeValue.ZERO, Setting.Property.Dynamic, Setting.Property.NodeScope); public static final TimeValue STATE_PERSIST_RESTORE_TIMEOUT = TimeValue.timeValueMinutes(30); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index cd32505a48e3e..9ac63f026b089 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -11,6 +11,8 @@ import org.elasticsearch.persistent.PersistentTasksClusterService; import org.elasticsearch.persistent.PersistentTasksCustomMetaData; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsTaskState; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.config.JobTaskState; @@ -23,9 +25,11 @@ public final class MlTasks { public static final String JOB_TASK_NAME = "xpack/ml/job"; public static final String DATAFEED_TASK_NAME = "xpack/ml/datafeed"; + public static final String DATA_FRAME_ANALYTICS_TASK_NAME = "xpack/ml/data_frame/analytics"; public static final String JOB_TASK_ID_PREFIX = "job-"; public static final String DATAFEED_TASK_ID_PREFIX = "datafeed-"; + public static final String DATA_FRAME_ANALYTICS_TASK_ID_PREFIX = "data_frame_analytics-"; public static final PersistentTasksCustomMetaData.Assignment AWAITING_UPGRADE = new PersistentTasksCustomMetaData.Assignment(null, @@ -50,6 +54,17 @@ public static String datafeedTaskId(String datafeedId) { return DATAFEED_TASK_ID_PREFIX + datafeedId; } + /** + * Namespaces the task ids for data frame analytics. + */ + public static String dataFrameAnalyticsTaskId(String id) { + return DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + id; + } + + public static String dataFrameAnalyticsIdFromTaskId(String taskId) { + return taskId.replaceFirst(DATA_FRAME_ANALYTICS_TASK_ID_PREFIX, ""); + } + @Nullable public static PersistentTasksCustomMetaData.PersistentTask getJobTask(String jobId, @Nullable PersistentTasksCustomMetaData tasks) { return tasks == null ? null : tasks.getTask(jobTaskId(jobId)); @@ -61,6 +76,12 @@ public static PersistentTasksCustomMetaData.PersistentTask getDatafeedTask(St return tasks == null ? null : tasks.getTask(datafeedTaskId(datafeedId)); } + @Nullable + public static PersistentTasksCustomMetaData.PersistentTask getDataFrameAnalyticsTask(String analyticsId, + @Nullable PersistentTasksCustomMetaData tasks) { + return tasks == null ? null : tasks.getTask(dataFrameAnalyticsTaskId(analyticsId)); + } + /** * Note that the return value of this method does NOT take node relocations into account. 
* Use {@link #getJobStateModifiedForReassignments} to return a value adjusted to the most @@ -120,6 +141,16 @@ public static DatafeedState getDatafeedState(String datafeedId, @Nullable Persis } } + public static DataFrameAnalyticsState getDataFrameAnalyticsState(String analyticsId, @Nullable PersistentTasksCustomMetaData tasks) { + PersistentTasksCustomMetaData.PersistentTask task = getDataFrameAnalyticsTask(analyticsId, tasks); + if (task != null && task.getState() != null) { + DataFrameAnalyticsTaskState taskState = (DataFrameAnalyticsTaskState) task.getState(); + return taskState.getState(); + } else { + return DataFrameAnalyticsState.STOPPED; + } + } + /** * The job Ids of anomaly detector job tasks. * All anomaly detector jobs are returned regardless of the status of the diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..9a777b23a4bb8 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/DeleteDataFrameAnalyticsAction.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeOperationRequestBuilder; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class DeleteDataFrameAnalyticsAction extends Action { + + public static final DeleteDataFrameAnalyticsAction INSTANCE = new DeleteDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/delete"; + + private DeleteDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public AcknowledgedResponse newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public Writeable.Reader getResponseReader() { + return AcknowledgedResponse::new; + } + + public static class Request extends AcknowledgedRequest implements ToXContentFragment { + + private String id; + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + } + + public Request() {} + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public String getId() { + return id; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + 
@Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DeleteDataFrameAnalyticsAction.Request request = (DeleteDataFrameAnalyticsAction.Request) o; + return Objects.equals(id, request.id); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + + public static class RequestBuilder extends MasterNodeOperationRequestBuilder { + + protected RequestBuilder(ElasticsearchClient client, DeleteDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java new file mode 100644 index 0000000000000..eec58428d55cd --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameAction.java @@ -0,0 +1,215 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +public class EvaluateDataFrameAction extends Action { + + public static final EvaluateDataFrameAction INSTANCE = new EvaluateDataFrameAction(); + public static final String NAME = "cluster:monitor/xpack/ml/data_frame/evaluate"; + + private EvaluateDataFrameAction() { + super(NAME); + } + + @Override + public Response newResponse() { + return new Response(); + } + + public static class Request extends ActionRequest implements ToXContentObject { + + private static final ParseField INDEX = new ParseField("index"); + private static final ParseField EVALUATION = new ParseField("evaluation"); + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + a -> new Request((List) a[0], (Evaluation) a[1])); + + static { + 
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), INDEX); + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseEvaluation(p), EVALUATION); + } + + private static Evaluation parseEvaluation(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + Evaluation evaluation = parser.namedObject(Evaluation.class, parser.currentName(), null); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return evaluation; + } + + public static Request parseRequest(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private String[] indices; + private Evaluation evaluation; + + private Request(List indices, Evaluation evaluation) { + setIndices(indices); + setEvaluation(evaluation); + } + + public Request() { + } + + public String[] getIndices() { + return indices; + } + + public final void setIndices(List indices) { + ExceptionsHelper.requireNonNull(indices, INDEX); + if (indices.isEmpty()) { + throw ExceptionsHelper.badRequestException("At least one index must be specified"); + } + this.indices = indices.toArray(new String[indices.size()]); + } + + public Evaluation getEvaluation() { + return evaluation; + } + + public final void setEvaluation(Evaluation evaluation) { + this.evaluation = ExceptionsHelper.requireNonNull(evaluation, EVALUATION); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + indices = in.readStringArray(); + evaluation = in.readNamedWriteable(Evaluation.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); 
+ out.writeStringArray(indices); + out.writeNamedWriteable(evaluation); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.array(INDEX.getPreferredName(), indices); + builder.startObject(EVALUATION.getPreferredName()); + builder.field(evaluation.getName(), evaluation); + builder.endObject(); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(indices), evaluation); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request that = (Request) o; + return Arrays.equals(indices, that.indices) && Objects.equals(evaluation, that.evaluation); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client) { + super(client, INSTANCE, new Request()); + } + } + + public static class Response extends ActionResponse implements ToXContentObject { + + private String evaluationName; + private List metrics; + + public Response() { + } + + public Response(String evaluationName, List metrics) { + this.evaluationName = Objects.requireNonNull(evaluationName); + this.metrics = Objects.requireNonNull(metrics); + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + this.evaluationName = in.readString(); + this.metrics = in.readNamedWriteableList(EvaluationMetricResult.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(evaluationName); + out.writeList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startObject(evaluationName); + for (EvaluationMetricResult metric : metrics) { + builder.field(metric.getName(), metric); + } + builder.endObject(); 
+ builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(evaluationName, metrics); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Response other = (Response) obj; + return Objects.equals(evaluationName, other.evaluationName) && Objects.equals(metrics, other.metrics); + } + + @Override + public final String toString() { + return Strings.toString(this); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..92233fbb27692 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsAction.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest; +import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; + +import java.io.IOException; +import java.util.Collections; + +public class GetDataFrameAnalyticsAction extends Action { + + public static final GetDataFrameAnalyticsAction INSTANCE = new GetDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/get"; + + private GetDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + return new Response(new QueryPage<>(Collections.emptyList(), 0, Response.RESULTS_FIELD)); + } + + public static class Request extends AbstractGetResourcesRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + public Request() { + setAllowNoResources(true); + } + + public Request(String id) { + setResourceId(id); + setAllowNoResources(true); + } + + public Request(StreamInput in) throws IOException { + readFrom(in); + } + + @Override + public String getResourceIdField() { + return DataFrameAnalyticsConfig.ID.getPreferredName(); + } + } + + public static class Response extends AbstractGetResourcesResponse { + + public static final ParseField RESULTS_FIELD = new ParseField("data_frame_analytics"); + + public Response() {} + + public Response(QueryPage analytics) { + super(analytics); + } + + @Override + protected Reader getReader() { + return DataFrameAnalyticsConfig::new; + } + } + + public static class RequestBuilder extends 
ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client) { + super(client, INSTANCE, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java new file mode 100644 index 0000000000000..b14feaa8839f5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java @@ -0,0 +1,321 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.TaskOperationFailure; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import 
org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class GetDataFrameAnalyticsStatsAction extends Action { + + public static final GetDataFrameAnalyticsStatsAction INSTANCE = new GetDataFrameAnalyticsStatsAction(); + public static final String NAME = "cluster:monitor/xpack/ml/data_frame/analytics/stats/get"; + + private GetDataFrameAnalyticsStatsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public Writeable.Reader getResponseReader() { + return Response::new; + } + + public static class Request extends BaseTasksRequest { + + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private String id; + private boolean allowNoMatch = true; + private PageParams pageParams = PageParams.defaultParams(); + + // Used internally to store the expanded IDs + private List expandedIds = Collections.emptyList(); + + public Request(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID.getPreferredName()); + this.expandedIds = Collections.singletonList(id); + } + + public Request() {} + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + allowNoMatch = in.readBoolean(); + pageParams = in.readOptionalWriteable(PageParams::new); + expandedIds = in.readStringList(); + } + + public void setExpandedIds(List expandedIds) { + this.expandedIds = Objects.requireNonNull(expandedIds); + } + + public List getExpandedIds() { + return expandedIds; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + 
out.writeString(id); + out.writeBoolean(allowNoMatch); + out.writeOptionalWriteable(pageParams); + out.writeStringCollection(expandedIds); + } + + public void setId(String id) { + this.id = id; + } + + public String getId() { + return id; + } + + public boolean isAllowNoMatch() { + return allowNoMatch; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + public void setPageParams(PageParams pageParams) { + this.pageParams = pageParams; + } + + public PageParams getPageParams() { + return pageParams; + } + + @Override + public boolean match(Task task) { + return expandedIds.stream().anyMatch(expandedId -> StartDataFrameAnalyticsAction.TaskMatcher.match(task, expandedId)); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public int hashCode() { + return Objects.hash(id, allowNoMatch, pageParams); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Request other = (Request) obj; + return Objects.equals(id, other.id) && allowNoMatch == other.allowNoMatch && Objects.equals(pageParams, other.pageParams); + } + } + + public static class RequestBuilder extends ActionRequestBuilder { + + public RequestBuilder(ElasticsearchClient client, GetDataFrameAnalyticsStatsAction action) { + super(client, action, new Request()); + } + } + + public static class Response extends BaseTasksResponse implements ToXContentObject { + + public static class Stats implements ToXContentObject, Writeable { + + private final String id; + private final DataFrameAnalyticsState state; + @Nullable + private final Integer progressPercentage; + @Nullable + private final DiscoveryNode node; + @Nullable + private final String assignmentExplanation; + + public Stats(String id, DataFrameAnalyticsState state, @Nullable Integer progressPercentage, + @Nullable DiscoveryNode node, @Nullable String 
assignmentExplanation) { + this.id = Objects.requireNonNull(id); + this.state = Objects.requireNonNull(state); + this.progressPercentage = progressPercentage; + this.node = node; + this.assignmentExplanation = assignmentExplanation; + } + + public Stats(StreamInput in) throws IOException { + id = in.readString(); + state = DataFrameAnalyticsState.fromStream(in); + progressPercentage = in.readOptionalInt(); + node = in.readOptionalWriteable(DiscoveryNode::new); + assignmentExplanation = in.readOptionalString(); + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsState getState() { + return state; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + // TODO: Have callers wrap the content with an object as they choose rather than forcing it upon them + builder.startObject(); + { + toUnwrappedXContent(builder); + } + return builder.endObject(); + } + + public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOException { + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + builder.field("state", state.toString()); + if (progressPercentage != null) { + builder.field("progress_percent", progressPercentage); + } + if (node != null) { + builder.startObject("node"); + builder.field("id", node.getId()); + builder.field("name", node.getName()); + builder.field("ephemeral_id", node.getEphemeralId()); + builder.field("transport_address", node.getAddress().toString()); + + builder.startObject("attributes"); + for (Map.Entry entry : node.getAttributes().entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + builder.endObject(); + } + if (assignmentExplanation != null) { + builder.field("assignment_explanation", assignmentExplanation); + } + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + state.writeTo(out); + 
+ out.writeOptionalInt(progressPercentage); + out.writeOptionalWriteable(node); + out.writeOptionalString(assignmentExplanation); + } + + @Override + public int hashCode() { + return Objects.hash(id, state, progressPercentage, node, assignmentExplanation); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Stats other = (Stats) obj; // compare every field that hashCode() hashes + return Objects.equals(id, other.id) && Objects.equals(this.state, other.state) + && Objects.equals(this.progressPercentage, other.progressPercentage) + && Objects.equals(this.node, other.node) + && Objects.equals(this.assignmentExplanation, other.assignmentExplanation); + } + } + + private QueryPage stats; + + public Response(QueryPage stats) { + this(Collections.emptyList(), Collections.emptyList(), stats); + } + + public Response(List taskFailures, List nodeFailures, + QueryPage stats) { + super(taskFailures, nodeFailures); + this.stats = stats; + } + + public Response(StreamInput in) throws IOException { + super(in); + stats = new QueryPage<>(in, Stats::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + stats.writeTo(out); + } + + public QueryPage getResponse() { + return stats; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + stats.doXContentBody(builder, params); + builder.endObject(); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(stats); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Response other = (Response) obj; + return Objects.equals(stats, other.stats); + } + + @Override + public final String toString() { + return Strings.toString(this); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..e447aa70109e7 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsAction.java @@ -0,0 +1,153 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.MasterNodeOperationRequestBuilder; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; + +import java.io.IOException; +import java.util.Objects; + +public class PutDataFrameAnalyticsAction extends Action { + + public static final PutDataFrameAnalyticsAction INSTANCE = new PutDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/put"; + + private PutDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + return new Response(); + } + + public static class Request extends AcknowledgedRequest implements ToXContentObject { + + public static Request 
parseRequest(String id, XContentParser parser) { + DataFrameAnalyticsConfig.Builder config = DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null); + if (config.getId() == null) { + config.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(config.getId())) { + // If we have both URI and body ID, they must be identical + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + config.getId(), id)); + } + + return new PutDataFrameAnalyticsAction.Request(config.build()); + } + + private DataFrameAnalyticsConfig config; + + public Request() {} + + public Request(DataFrameAnalyticsConfig config) { + this.config = config; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + config = new DataFrameAnalyticsConfig(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + config.writeTo(out); + } + + public DataFrameAnalyticsConfig getConfig() { + return config; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + config.toXContent(builder, params); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PutDataFrameAnalyticsAction.Request request = (PutDataFrameAnalyticsAction.Request) o; + return Objects.equals(config, request.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } + } + + public static class Response extends ActionResponse implements ToXContentObject { + + private DataFrameAnalyticsConfig config; + + public Response(DataFrameAnalyticsConfig config) { + this.config = config; + } + + Response() {} + + @Override + public void readFrom(StreamInput in) throws IOException { + super.readFrom(in); + config 
= new DataFrameAnalyticsConfig(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + config.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + config.toXContent(builder, params); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Response response = (Response) o; + return Objects.equals(config, response.config); + } + + @Override + public int hashCode() { + return Objects.hash(config); + } + } + + public static class RequestBuilder extends MasterNodeOperationRequestBuilder { + + protected RequestBuilder(ElasticsearchClient client, PutDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..d722198bdfae6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.Version; +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.MasterNodeRequest; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Objects; + +public class StartDataFrameAnalyticsAction extends Action { + + public static final StartDataFrameAnalyticsAction INSTANCE = new StartDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/start"; + + private StartDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public AcknowledgedResponse newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + 
public Writeable.Reader getResponseReader() { + return AcknowledgedResponse::new; + } + + public static class Request extends MasterNodeRequest implements ToXContentObject { + + public static final ParseField TIMEOUT = new ParseField("timeout"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + + static { + PARSER.declareString((request, id) -> request.id = id, DataFrameAnalyticsConfig.ID); + PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); + } + + public static Request parseRequest(String id, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (request.getId() == null) { + request.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(request.getId())) { + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + request.getId(), id)); + } + return request; + } + + private String id; + private TimeValue timeout = TimeValue.timeValueSeconds(20); + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + timeout = in.readTimeValue(); + } + + public Request() {} + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public String getId() { + return id; + } + + public void setTimeout(TimeValue timeout) { + this.timeout = timeout; + } + + public TimeValue getTimeout() { + return timeout; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + out.writeTimeValue(timeout); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (id != null) { + 
builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + } + builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); + return builder; + } + + @Override + public int hashCode() { + return Objects.hash(id, timeout); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + StartDataFrameAnalyticsAction.Request other = (StartDataFrameAnalyticsAction.Request) obj; + return Objects.equals(id, other.id) && Objects.equals(timeout, other.timeout); + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client, StartDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } + + public static class TaskParams implements XPackPlugin.XPackPersistentTaskParams { + + // TODO Update to first released version + public static final Version VERSION_INTRODUCED = Version.V_7_1_0; + + public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0])); + + public static TaskParams fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private String id; + + public TaskParams(String id) { + this.id = Objects.requireNonNull(id); + } + + public TaskParams(StreamInput in) throws IOException { + this.id = in.readString(); + } + + public String getId() { + return id; + } + + @Override + public String getWriteableName() { + return MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return VERSION_INTRODUCED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + 
builder.startObject(); + builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id); + builder.endObject(); + return builder; + } + } + + public interface TaskMatcher { + + static boolean match(Task task, String expectedId) { + if (task instanceof TaskMatcher) { + if (MetaData.ALL.equals(expectedId)) { + return true; + } + String expectedDescription = MlTasks.DATA_FRAME_ANALYTICS_TASK_ID_PREFIX + expectedId; + return expectedDescription.equals(task.getDescription()); + } + return false; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java new file mode 100644 index 0000000000000..43d382147fd64 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsAction.java @@ -0,0 +1,223 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.Action; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.support.tasks.BaseTasksRequest; +import org.elasticsearch.action.support.tasks.BaseTasksResponse; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +public class StopDataFrameAnalyticsAction extends Action { + + public static final StopDataFrameAnalyticsAction INSTANCE = new StopDataFrameAnalyticsAction(); + public static final String NAME = "cluster:admin/xpack/ml/data_frame/analytics/stop"; + + private StopDataFrameAnalyticsAction() { + super(NAME); + } + + @Override + public Response newResponse() { + throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable"); + } + + @Override + public Writeable.Reader getResponseReader() { + return Response::new; + } + + public static class Request extends BaseTasksRequest implements ToXContentObject { + + 
public static final ParseField TIMEOUT = new ParseField("timeout"); + public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match"); + + private static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); + + static { + PARSER.declareString((request, id) -> request.id = id, DataFrameAnalyticsConfig.ID); + PARSER.declareString((request, val) -> request.setTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); + } + + public static Request parseRequest(String id, XContentParser parser) { + Request request = PARSER.apply(parser, null); + if (request.getId() == null) { + request.setId(id); + } else if (!Strings.isNullOrEmpty(id) && !id.equals(request.getId())) { + throw new IllegalArgumentException(Messages.getMessage(Messages.INCONSISTENT_ID, DataFrameAnalyticsConfig.ID, + request.getId(), id)); + } + return request; + } + + private String id; + private Set expandedIds = Collections.emptySet(); + private boolean allowNoMatch = true; + + public Request(String id) { + setId(id); + } + + public Request(StreamInput in) throws IOException { + super(in); + id = in.readString(); + expandedIds = new HashSet<>(Arrays.asList(in.readStringArray())); + allowNoMatch = in.readBoolean(); + } + + public Request() {} + + public final void setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, DataFrameAnalyticsConfig.ID); + } + + public String getId() { + return id; + } + + @Nullable + public Set getExpandedIds() { + return expandedIds; + } + + public void setExpandedIds(Set expandedIds) { + this.expandedIds = Objects.requireNonNull(expandedIds); + } + + public boolean allowNoMatch() { + return allowNoMatch; + } + + public void setAllowNoMatch(boolean allowNoMatch) { + this.allowNoMatch = allowNoMatch; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(id); + 
out.writeStringArray(expandedIds.toArray(new String[0])); + out.writeBoolean(allowNoMatch); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder + .startObject() + .field(DataFrameAnalyticsConfig.ID.getPreferredName(), id) + .field(ALLOW_NO_MATCH.getPreferredName(), allowNoMatch) + .endObject(); + } + + @Override + public int hashCode() { + return Objects.hash(id, getTimeout(), expandedIds, allowNoMatch); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || obj.getClass() != getClass()) { + return false; + } + StopDataFrameAnalyticsAction.Request other = (StopDataFrameAnalyticsAction.Request) obj; + return Objects.equals(id, other.id) + && Objects.equals(getTimeout(), other.getTimeout()) + && Objects.equals(expandedIds, other.expandedIds) + && allowNoMatch == other.allowNoMatch; + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + public static class Response extends BaseTasksResponse implements Writeable, ToXContentObject { + + private final boolean stopped; + + public Response(boolean stopped) { + super(null, null); + this.stopped = stopped; + } + + public Response(StreamInput in) throws IOException { + super(in); + stopped = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeBoolean(stopped); + } + + public boolean isStopped() { + return stopped; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + toXContentCommon(builder, params); + builder.field("stopped", stopped); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + Response response = (Response) o; + return stopped == 
response.stopped; + } + + @Override + public int hashCode() { + return Objects.hash(stopped); + } + } + + static class RequestBuilder extends ActionRequestBuilder { + + RequestBuilder(ElasticsearchClient client, StopDataFrameAnalyticsAction action) { + super(client, action, new Request()); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java index 6889f5199526d..54d4869bf2824 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfig.java @@ -30,6 +30,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import org.elasticsearch.xpack.core.ml.utils.time.TimeUtils; @@ -122,7 +123,7 @@ private static ObjectParser createParser(boolean ignoreUnknownFie parser.declareString((builder, val) -> builder.setFrequency(TimeValue.parseTimeValue(val, FREQUENCY.getPreferredName())), FREQUENCY); parser.declareObject(Builder::setQueryProvider, - (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields), + (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT), QUERY); parser.declareObject(Builder::setAggregationsSafe, (p, c) -> AggProvider.fromXContent(p, ignoreUnknownFields), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java index c1005bb971a56..9f38d0323f244 
100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java @@ -22,7 +22,9 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; @@ -53,7 +55,8 @@ public class DatafeedUpdate implements Writeable, ToXContentObject { TimeValue.parseTimeValue(val, DatafeedConfig.QUERY_DELAY.getPreferredName())), DatafeedConfig.QUERY_DELAY); PARSER.declareString((builder, val) -> builder.setFrequency( TimeValue.parseTimeValue(val, DatafeedConfig.FREQUENCY.getPreferredName())), DatafeedConfig.FREQUENCY); - PARSER.declareObject(Builder::setQuery, (p, c) -> QueryProvider.fromXContent(p, false), DatafeedConfig.QUERY); + PARSER.declareObject(Builder::setQuery, (p, c) -> QueryProvider.fromXContent(p, false, Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT), + DatafeedConfig.QUERY); PARSER.declareObject(Builder::setAggregationsSafe, (p, c) -> AggProvider.fromXContent(p, false), DatafeedConfig.AGGREGATIONS); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java new file mode 100644 index 0000000000000..0e9acdd44a2fe --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java @@ -0,0 +1,312 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.OBJECT_ARRAY_BOOLEAN_OR_STRING; +import static org.elasticsearch.common.xcontent.ObjectParser.ValueType.VALUE; + +public class DataFrameAnalyticsConfig implements ToXContentObject, Writeable { + + public static final String TYPE = "data_frame_analytics_config"; + + public static final ByteSizeValue DEFAULT_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.GB); + public static final ByteSizeValue MIN_MODEL_MEMORY_LIMIT = new ByteSizeValue(1, ByteSizeUnit.MB); + public static final ByteSizeValue PROCESS_MEMORY_OVERHEAD = new ByteSizeValue(20, ByteSizeUnit.MB); + + public static final ParseField ID = new ParseField("id"); + public static final ParseField 
SOURCE = new ParseField("source"); + public static final ParseField DEST = new ParseField("dest"); + public static final ParseField ANALYSIS = new ParseField("analysis"); + public static final ParseField CONFIG_TYPE = new ParseField("config_type"); + public static final ParseField ANALYZED_FIELDS = new ParseField("analyzed_fields"); + public static final ParseField MODEL_MEMORY_LIMIT = new ParseField("model_memory_limit"); + public static final ParseField HEADERS = new ParseField("headers"); + + public static final ObjectParser STRICT_PARSER = createParser(false); + public static final ObjectParser LENIENT_PARSER = createParser(true); + + public static ObjectParser createParser(boolean ignoreUnknownFields) { + ObjectParser parser = new ObjectParser<>(TYPE, ignoreUnknownFields, Builder::new); + + parser.declareString((c, s) -> {}, CONFIG_TYPE); + parser.declareString(Builder::setId, ID); + parser.declareObject(Builder::setSource, DataFrameAnalyticsSource.createParser(ignoreUnknownFields), SOURCE); + parser.declareObject(Builder::setDest, DataFrameAnalyticsDest.createParser(ignoreUnknownFields), DEST); + parser.declareObject(Builder::setAnalysis, (p, c) -> parseAnalysis(p, ignoreUnknownFields), ANALYSIS); + parser.declareField(Builder::setAnalyzedFields, + (p, c) -> FetchSourceContext.fromXContent(p), + ANALYZED_FIELDS, + OBJECT_ARRAY_BOOLEAN_OR_STRING); + parser.declareField(Builder::setModelMemoryLimit, + (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), MODEL_MEMORY_LIMIT.getPreferredName()), MODEL_MEMORY_LIMIT, VALUE); + if (ignoreUnknownFields) { + // Headers are not parsed by the strict (config) parser, so headers supplied in the _body_ of a REST request will be rejected. + // (For config, headers are explicitly transferred from the auth headers by code in the put data frame actions.) 
+ parser.declareObject(Builder::setHeaders, (p, c) -> p.mapStrings(), HEADERS); + } + return parser; + } + + private static DataFrameAnalysis parseAnalysis(XContentParser parser, boolean ignoreUnknownFields) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser::getTokenLocation); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.FIELD_NAME, parser.nextToken(), parser::getTokenLocation); + DataFrameAnalysis analysis = parser.namedObject(DataFrameAnalysis.class, parser.currentName(), ignoreUnknownFields); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.nextToken(), parser::getTokenLocation); + return analysis; + } + + private final String id; + private final DataFrameAnalyticsSource source; + private final DataFrameAnalyticsDest dest; + private final DataFrameAnalysis analysis; + private final FetchSourceContext analyzedFields; + /** + * This may be null up to the point of persistence, as the relationship with xpack.ml.max_model_memory_limit + * depends on whether the user explicitly set the value or if the default was requested. null indicates + * the default was requested, which in turn means a default higher than the maximum is silently capped. + * A non-null value higher than xpack.ml.max_model_memory_limit will cause a + * validation error even if it is equal to the default value. This behaviour matches what is done in + * {@link org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits}. 
+ */ + private final ByteSizeValue modelMemoryLimit; + private final Map headers; + + public DataFrameAnalyticsConfig(String id, DataFrameAnalyticsSource source, DataFrameAnalyticsDest dest, + DataFrameAnalysis analysis, Map headers, ByteSizeValue modelMemoryLimit, + FetchSourceContext analyzedFields) { + this.id = ExceptionsHelper.requireNonNull(id, ID); + this.source = ExceptionsHelper.requireNonNull(source, SOURCE); + this.dest = ExceptionsHelper.requireNonNull(dest, DEST); + this.analysis = ExceptionsHelper.requireNonNull(analysis, ANALYSIS); + this.analyzedFields = analyzedFields; + this.modelMemoryLimit = modelMemoryLimit; + this.headers = Collections.unmodifiableMap(headers); + } + + public DataFrameAnalyticsConfig(StreamInput in) throws IOException { + id = in.readString(); + source = new DataFrameAnalyticsSource(in); + dest = new DataFrameAnalyticsDest(in); + analysis = in.readNamedWriteable(DataFrameAnalysis.class); + this.analyzedFields = in.readOptionalWriteable(FetchSourceContext::new); + this.modelMemoryLimit = in.readOptionalWriteable(ByteSizeValue::new); + this.headers = Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)); + } + + public String getId() { + return id; + } + + public DataFrameAnalyticsSource getSource() { + return source; + } + + public DataFrameAnalyticsDest getDest() { + return dest; + } + + public DataFrameAnalysis getAnalysis() { + return analysis; + } + + public FetchSourceContext getAnalyzedFields() { + return analyzedFields; + } + + public ByteSizeValue getModelMemoryLimit() { + return modelMemoryLimit != null ? 
modelMemoryLimit : DEFAULT_MODEL_MEMORY_LIMIT; + } + + public Map getHeaders() { + return headers; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ID.getPreferredName(), id); + builder.field(SOURCE.getPreferredName(), source); + builder.field(DEST.getPreferredName(), dest); + + builder.startObject(ANALYSIS.getPreferredName()); + builder.field(analysis.getWriteableName(), analysis); + builder.endObject(); + + if (params.paramAsBoolean(ToXContentParams.INCLUDE_TYPE, false)) { + builder.field(CONFIG_TYPE.getPreferredName(), TYPE); + } + if (analyzedFields != null) { + builder.field(ANALYZED_FIELDS.getPreferredName(), analyzedFields); + } + builder.field(MODEL_MEMORY_LIMIT.getPreferredName(), getModelMemoryLimit().getStringRep()); + if (headers.isEmpty() == false && params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(HEADERS.getPreferredName(), headers); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + source.writeTo(out); + dest.writeTo(out); + out.writeNamedWriteable(analysis); + out.writeOptionalWriteable(analyzedFields); + out.writeOptionalWriteable(modelMemoryLimit); + out.writeMap(headers, StreamOutput::writeString, StreamOutput::writeString); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsConfig other = (DataFrameAnalyticsConfig) o; + return Objects.equals(id, other.id) + && Objects.equals(source, other.source) + && Objects.equals(dest, other.dest) + && Objects.equals(analysis, other.analysis) + && Objects.equals(headers, other.headers) + && Objects.equals(getModelMemoryLimit(), other.getModelMemoryLimit()) + && Objects.equals(analyzedFields, other.analyzedFields); + } + + @Override + public int hashCode() 
{ + return Objects.hash(id, source, dest, analysis, headers, getModelMemoryLimit(), analyzedFields); + } + + public static String documentId(String id) { + return TYPE + "-" + id; + } + + public static class Builder { + + private String id; + private DataFrameAnalyticsSource source; + private DataFrameAnalyticsDest dest; + private DataFrameAnalysis analysis; + private FetchSourceContext analyzedFields; + private ByteSizeValue modelMemoryLimit; + private ByteSizeValue maxModelMemoryLimit; + private Map headers = Collections.emptyMap(); + + public Builder() {} + + public Builder(String id) { + setId(id); + } + + public Builder(ByteSizeValue maxModelMemoryLimit) { + this.maxModelMemoryLimit = maxModelMemoryLimit; + } + + public Builder(DataFrameAnalyticsConfig config) { + this(config, null); + } + + public Builder(DataFrameAnalyticsConfig config, ByteSizeValue maxModelMemoryLimit) { + this.id = config.id; + this.source = new DataFrameAnalyticsSource(config.source); + this.dest = new DataFrameAnalyticsDest(config.dest); + this.analysis = config.analysis; + this.headers = new HashMap<>(config.headers); + this.modelMemoryLimit = config.modelMemoryLimit; + this.maxModelMemoryLimit = maxModelMemoryLimit; + if (config.analyzedFields != null) { + this.analyzedFields = new FetchSourceContext(true, config.analyzedFields.includes(), config.analyzedFields.excludes()); + } + } + + public String getId() { + return id; + } + + public Builder setId(String id) { + this.id = ExceptionsHelper.requireNonNull(id, ID); + return this; + } + + public Builder setSource(DataFrameAnalyticsSource source) { + this.source = ExceptionsHelper.requireNonNull(source, SOURCE); + return this; + } + + public Builder setDest(DataFrameAnalyticsDest dest) { + this.dest = ExceptionsHelper.requireNonNull(dest, DEST); + return this; + } + + public Builder setAnalysis(DataFrameAnalysis analysis) { + this.analysis = ExceptionsHelper.requireNonNull(analysis, ANALYSIS); + return this; + } + + public Builder 
setAnalyzedFields(FetchSourceContext fields) { + this.analyzedFields = fields; + return this; + } + + public Builder setHeaders(Map headers) { + this.headers = headers; + return this; + } + + public Builder setModelMemoryLimit(ByteSizeValue modelMemoryLimit) { + if (modelMemoryLimit != null && modelMemoryLimit.compareTo(MIN_MODEL_MEMORY_LIMIT) < 0) { + throw new IllegalArgumentException("[" + MODEL_MEMORY_LIMIT.getPreferredName() + + "] must be at least [" + MIN_MODEL_MEMORY_LIMIT.getStringRep() + "]"); + } + this.modelMemoryLimit = modelMemoryLimit; + return this; + } + + private void applyMaxModelMemoryLimit() { + + boolean maxModelMemoryIsSet = maxModelMemoryLimit != null && maxModelMemoryLimit.getMb() > 0; + + if (modelMemoryLimit == null) { + // Default is silently capped if higher than limit + if (maxModelMemoryIsSet && DEFAULT_MODEL_MEMORY_LIMIT.compareTo(maxModelMemoryLimit) > 0) { + modelMemoryLimit = maxModelMemoryLimit; + } + } else if (maxModelMemoryIsSet && modelMemoryLimit.compareTo(maxModelMemoryLimit) > 0) { + // Explicit setting higher than limit is an error + throw ExceptionsHelper.badRequestException(Messages.getMessage(Messages.JOB_CONFIG_MODEL_MEMORY_LIMIT_GREATER_THAN_MAX, + modelMemoryLimit, maxModelMemoryLimit)); + } + } + + public DataFrameAnalyticsConfig build() { + applyMaxModelMemoryLimit(); + return new DataFrameAnalyticsConfig(id, source, dest, analysis, headers, modelMemoryLimit, analyzedFields); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java new file mode 100644 index 0000000000000..3bc435336f062 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDest.java @@ -0,0 +1,106 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.indices.InvalidIndexNameException; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; + +import static org.elasticsearch.cluster.metadata.MetaDataCreateIndexService.validateIndexOrAliasName; + +public class DataFrameAnalyticsDest implements Writeable, ToXContentObject { + + public static final ParseField INDEX = new ParseField("index"); + public static final ParseField RESULTS_FIELD = new ParseField("results_field"); + + private static final String DEFAULT_RESULTS_FIELD = "ml"; + + public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("data_frame_analytics_dest", + ignoreUnknownFields, a -> new DataFrameAnalyticsDest((String) a[0], (String) a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), INDEX); + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), RESULTS_FIELD); + return parser; + } + + private final String index; + private final String resultsField; + + public DataFrameAnalyticsDest(String index, @Nullable String resultsField) { + this.index = ExceptionsHelper.requireNonNull(index, INDEX); + if (index.isEmpty()) { + throw 
ExceptionsHelper.badRequestException("[{}] must be non-empty", INDEX); + } + this.resultsField = resultsField == null ? DEFAULT_RESULTS_FIELD : resultsField; + } + + public DataFrameAnalyticsDest(StreamInput in) throws IOException { + index = in.readString(); + resultsField = in.readString(); + } + + public DataFrameAnalyticsDest(DataFrameAnalyticsDest other) { + this.index = other.index; + this.resultsField = other.resultsField; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(index); + out.writeString(resultsField); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + builder.field(RESULTS_FIELD.getPreferredName(), resultsField); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsDest other = (DataFrameAnalyticsDest) o; + return Objects.equals(index, other.index) && Objects.equals(resultsField, other.resultsField); + } + + @Override + public int hashCode() { + return Objects.hash(index, resultsField); + } + + public String getIndex() { + return index; + } + + public String getResultsField() { + return resultsField; + } + + public void validate() { + if (index != null) { + validateIndexOrAliasName(index, InvalidIndexNameException::new); + if (index.toLowerCase(Locale.ROOT).equals(index) == false) { + throw new InvalidIndexNameException(index, "dest.index must be lowercase"); + } + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java new file mode 100644 index 0000000000000..a57de375f3989 --- /dev/null +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsSource.java @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; +import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class DataFrameAnalyticsSource implements Writeable, ToXContentObject { + + public static final ParseField INDEX = new ParseField("index"); + public static final ParseField QUERY = new ParseField("query"); + + public static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>("data_frame_analytics_source", + ignoreUnknownFields, a -> new DataFrameAnalyticsSource((String) a[0], (QueryProvider) a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), 
INDEX); + parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> QueryProvider.fromXContent(p, ignoreUnknownFields, Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT), QUERY); + return parser; + } + + private final String index; + private final QueryProvider queryProvider; + + public DataFrameAnalyticsSource(String index, @Nullable QueryProvider queryProvider) { + this.index = ExceptionsHelper.requireNonNull(index, INDEX); + if (index.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must be non-empty", INDEX); + } + this.queryProvider = queryProvider == null ? QueryProvider.defaultQuery() : queryProvider; + } + + public DataFrameAnalyticsSource(StreamInput in) throws IOException { + index = in.readString(); + queryProvider = QueryProvider.fromStream(in); + } + + public DataFrameAnalyticsSource(DataFrameAnalyticsSource other) { + this.index = other.index; + this.queryProvider = new QueryProvider(other.queryProvider); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(index); + queryProvider.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INDEX.getPreferredName(), index); + builder.field(QUERY.getPreferredName(), queryProvider.getQuery()); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (o == null || getClass() != o.getClass()) return false; + + DataFrameAnalyticsSource other = (DataFrameAnalyticsSource) o; + return Objects.equals(index, other.index) + && Objects.equals(queryProvider, other.queryProvider); + } + + @Override + public int hashCode() { + return Objects.hash(index, queryProvider); + } + + public String getIndex() { + return index; + } + + /** + * Get the fully parsed query from the semi-parsed stored {@code Map} + * + * @return Fully parsed query + */ + public 
QueryBuilder getParsedQuery() { + Exception exception = queryProvider.getParsingException(); + if (exception != null) { + if (exception instanceof RuntimeException) { + throw (RuntimeException) exception; + } else { + throw new ElasticsearchException(queryProvider.getParsingException()); + } + } + return queryProvider.getParsedQuery(); + } + + Exception getQueryParsingException() { + return queryProvider.getParsingException(); + } + + /** + * Calls the parser and returns any gathered deprecations + * + * @param namedXContentRegistry XContent registry to transform the lazily parsed query + * @return The deprecations from parsing the query + */ + public List getQueryDeprecations(NamedXContentRegistry namedXContentRegistry) { + List deprecations = new ArrayList<>(); + try { + XContentObjectTransformer.queryBuilderTransformer(namedXContentRegistry).fromMap(queryProvider.getQuery(), + deprecations); + } catch (Exception exception) { + // Certain thrown exceptions wrap up the real Illegal argument making it hard to determine cause for the user + if (exception.getCause() instanceof IllegalArgumentException) { + exception = (Exception) exception.getCause(); + } + throw ExceptionsHelper.badRequestException(Messages.DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT, exception); + } + return deprecations; + } + + public Map getQuery() { + return queryProvider.getQuery(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java new file mode 100644 index 0000000000000..d40df259eec57 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsState.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum DataFrameAnalyticsState implements Writeable { + + STARTED, REINDEXING, ANALYZING, STOPPING, STOPPED; + + public static DataFrameAnalyticsState fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static DataFrameAnalyticsState fromStream(StreamInput in) throws IOException { + return in.readEnum(DataFrameAnalyticsState.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeEnum(this); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java new file mode 100644 index 0000000000000..994faaaee6cc2 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsTaskState.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.xpack.core.ml.MlTasks; + +import java.io.IOException; +import java.util.Objects; + +public class DataFrameAnalyticsTaskState implements PersistentTaskState { + + public static final String NAME = MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; + + private static ParseField STATE = new ParseField("state"); + private static ParseField ALLOCATION_ID = new ParseField("allocation_id"); + + private final DataFrameAnalyticsState state; + private final long allocationId; + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, + a -> new DataFrameAnalyticsTaskState((DataFrameAnalyticsState) a[0], (long) a[1])); + + static { + PARSER.declareField(ConstructingObjectParser.constructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return DataFrameAnalyticsState.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, STATE, ObjectParser.ValueType.STRING); + PARSER.declareLong(ConstructingObjectParser.constructorArg(), ALLOCATION_ID); + } + + public static DataFrameAnalyticsTaskState fromXContent(XContentParser parser) { + try { + return PARSER.parse(parser, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public DataFrameAnalyticsTaskState(DataFrameAnalyticsState state, long allocationId) { + this.state = 
Objects.requireNonNull(state); + this.allocationId = allocationId; + } + + public DataFrameAnalyticsTaskState(StreamInput in) throws IOException { + this.state = DataFrameAnalyticsState.fromStream(in); + this.allocationId = in.readLong(); + } + + public DataFrameAnalyticsState getState() { + return state; + } + + public boolean isStatusStale(PersistentTasksCustomMetaData.PersistentTask task) { + return allocationId != task.getAllocationId(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + state.writeTo(out); + out.writeLong(allocationId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(STATE.getPreferredName(), state.toString()); + builder.field(ALLOCATION_ID.getPreferredName(), allocationId); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DataFrameAnalyticsTaskState that = (DataFrameAnalyticsTaskState) o; + return allocationId == that.allocationId && + state == that.state; + } + + @Override + public int hashCode() { + return Objects.hash(state, allocationId); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java new file mode 100644 index 0000000000000..f21533d917602 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +import java.util.Map; + +public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { + + Map getParams(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java new file mode 100644 index 0000000000000..a48a23e4a8393 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/MlDataFrameAnalysisNamedXContentProvider.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; + +import java.util.ArrayList; +import java.util.List; + +public class MlDataFrameAnalysisNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + List namedXContent = new ArrayList<>(); + + namedXContent.add(new NamedXContentRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME, (p, c) -> { + boolean ignoreUnknownFields = (boolean) c; + return OutlierDetection.fromXContent(p, ignoreUnknownFields); + })); + + return namedXContent; + } + + public List getNamedWriteables() { + List namedWriteables = new ArrayList<>(); + + namedWriteables.add(new NamedWriteableRegistry.Entry(DataFrameAnalysis.class, OutlierDetection.NAME.getPreferredName(), + OutlierDetection::new)); + + return namedWriteables; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java new file mode 100644 index 0000000000000..91eb02b7bcdfe --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -0,0 +1,169 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.analyses; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; + +public class OutlierDetection implements DataFrameAnalysis { + + public static final ParseField NAME = new ParseField("outlier_detection"); + + public static final ParseField N_NEIGHBORS = new ParseField("n_neighbors"); + public static final ParseField METHOD = new ParseField("method"); + public static final ParseField MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE = + new ParseField("minimum_score_to_write_feature_influence"); + + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean lenient) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient, + a -> new OutlierDetection((Integer) a[0], (Method) a[1], (Double) a[2])); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), N_NEIGHBORS); + parser.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return Method.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, METHOD, ObjectParser.ValueType.STRING); + 
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE); + return parser; + } + + public static OutlierDetection fromXContent(XContentParser parser, boolean ignoreUnknownFields) { + return ignoreUnknownFields ? LENIENT_PARSER.apply(parser, null) : STRICT_PARSER.apply(parser, null); + } + + private final Integer nNeighbors; + private final Method method; + private final Double minScoreToWriteFeatureInfluence; + + /** + * Constructs the outlier detection configuration + * @param nNeighbors The number of neighbors. Leave unspecified for dynamic detection. + * @param method The method. Leave unspecified for a dynamic mixture of methods. + * @param minScoreToWriteFeatureInfluence The min outlier score required to calculate feature influence. Defaults to 0.1. + */ + public OutlierDetection(@Nullable Integer nNeighbors, @Nullable Method method, @Nullable Double minScoreToWriteFeatureInfluence) { + if (nNeighbors != null && nNeighbors <= 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a positive integer", N_NEIGHBORS.getPreferredName()); + } + + if (minScoreToWriteFeatureInfluence != null && (minScoreToWriteFeatureInfluence < 0.0 || minScoreToWriteFeatureInfluence > 1.0)) { + throw ExceptionsHelper.badRequestException("[{}] must be in [0, 1]", + MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName()); + } + + this.nNeighbors = nNeighbors; + this.method = method; + this.minScoreToWriteFeatureInfluence = minScoreToWriteFeatureInfluence; + } + + /** + * Constructs the default outlier detection configuration + */ + public OutlierDetection() { + this(null, null, null); + } + + public OutlierDetection(StreamInput in) throws IOException { + nNeighbors = in.readOptionalVInt(); + method = in.readBoolean() ? 
in.readEnum(Method.class) : null; + minScoreToWriteFeatureInfluence = in.readOptionalDouble(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalVInt(nNeighbors); + + if (method != null) { + out.writeBoolean(true); + out.writeEnum(method); + } else { + out.writeBoolean(false); + } + + out.writeOptionalDouble(minScoreToWriteFeatureInfluence); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (nNeighbors != null) { + builder.field(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + builder.field(METHOD.getPreferredName(), method); + } + if (minScoreToWriteFeatureInfluence != null) { + builder.field(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OutlierDetection that = (OutlierDetection) o; + return Objects.equals(nNeighbors, that.nNeighbors) + && Objects.equals(method, that.method) + && Objects.equals(minScoreToWriteFeatureInfluence, that.minScoreToWriteFeatureInfluence); + } + + @Override + public int hashCode() { + return Objects.hash(nNeighbors, method, minScoreToWriteFeatureInfluence); + } + + @Override + public Map getParams() { + Map params = new HashMap<>(); + if (nNeighbors != null) { + params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); + } + if (method != null) { + params.put(METHOD.getPreferredName(), method); + } + if (minScoreToWriteFeatureInfluence != null) { + params.put(MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), minScoreToWriteFeatureInfluence); + } + return params; + } + + public enum Method { + LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN; + + 
public static Method fromString(String value) { + return Method.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java new file mode 100644 index 0000000000000..c01c19e33e865 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/Evaluation.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.search.builder.SearchSourceBuilder; + +import java.util.List; + +/** + * Defines an evaluation + */ +public interface Evaluation extends ToXContentObject, NamedWriteable { + + /** + * Returns the evaluation name + */ + String getName(); + + /** + * Builds the search required to collect data to compute the evaluation result + */ + SearchSourceBuilder buildSearch(); + + /** + * Computes the evaluation result + * @param searchResponse The search response required to compute the result + * @param listener A listener of the results + */ + void evaluate(SearchResponse searchResponse, ActionListener> listener); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java new file mode 100644 index 0000000000000..36b8adf9d4ea3 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/EvaluationMetricResult.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.xcontent.ToXContentObject; + +/** + * The result of an evaluation metric + */ +public interface EvaluationMetricResult extends ToXContentObject, NamedWriteable { + + /** + * Returns the name of the metric + */ + String getName(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java new file mode 100644 index 0000000000000..f4a6dba88e3b1 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/MlEvaluationNamedXContentProvider.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.AucRoc; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassification; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ConfusionMatrix; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Precision; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; + +import java.util.ArrayList; +import java.util.List; + +public class MlEvaluationNamedXContentProvider implements NamedXContentProvider { + + @Override + public List getNamedXContentParsers() { + List namedXContent = new ArrayList<>(); + + // Evaluations + namedXContent.add(new NamedXContentRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME, + BinarySoftClassification::fromXContent)); + + // Soft classification metrics + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME, AucRoc::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Precision.NAME, Precision::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, Recall.NAME, Recall::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME, + ConfusionMatrix::fromXContent)); + + return namedXContent; + } + + public List getNamedWriteables() { + List namedWriteables = new 
ArrayList<>(); + + // Evaluations + namedWriteables.add(new NamedWriteableRegistry.Entry(Evaluation.class, BinarySoftClassification.NAME.getPreferredName(), + BinarySoftClassification::new)); + + // Evaluation Metrics + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, AucRoc.NAME.getPreferredName(), + AucRoc::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Precision.NAME.getPreferredName(), + Precision::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, Recall.NAME.getPreferredName(), + Recall::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SoftClassificationMetric.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix::new)); + + // Evaluation Metrics Results + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, AucRoc.NAME.getPreferredName(), + AucRoc.Result::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ScoreByThresholdResult.NAME, + ScoreByThresholdResult::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetricResult.class, ConfusionMatrix.NAME.getPreferredName(), + ConfusionMatrix.Result::new)); + + return namedWriteables; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java new file mode 100644 index 0000000000000..facdcceea194f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AbstractConfusionMatrixMetric.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +abstract class AbstractConfusionMatrixMetric implements SoftClassificationMetric { + + public static final ParseField AT = new ParseField("at"); + + protected final double[] thresholds; + + protected AbstractConfusionMatrixMetric(double[] thresholds) { + this.thresholds = ExceptionsHelper.requireNonNull(thresholds, AT); + if (thresholds.length == 0) { + throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." + AT.getPreferredName() + + "] must have at least one value"); + } + for (double threshold : thresholds) { + if (threshold < 0 || threshold > 1.0) { + throw ExceptionsHelper.badRequestException("[" + getMetricName() + "." 
+ AT.getPreferredName() + + "] values must be in [0.0, 1.0]"); + } + } + } + + protected AbstractConfusionMatrixMetric(StreamInput in) throws IOException { + this.thresholds = in.readDoubleArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDoubleArray(thresholds); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(AT.getPreferredName(), thresholds); + builder.endObject(); + return builder; + } + + @Override + public final List aggs(String actualField, List classInfos) { + List aggs = new ArrayList<>(); + for (double threshold : thresholds) { + aggs.addAll(aggsAt(actualField, classInfos, threshold)); + } + return aggs; + } + + protected abstract List aggsAt(String labelField, List classInfos, double threshold); + + protected enum Condition { + TP, FP, TN, FN; + } + + protected String aggName(ClassInfo classInfo, double threshold, Condition condition) { + return getMetricName() + "_" + classInfo.getName() + "_at_" + threshold + "_" + condition.name(); + } + + protected AggregationBuilder buildAgg(ClassInfo classInfo, double threshold, Condition condition) { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + switch (condition) { + case TP: + boolQuery.must(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); + break; + case FP: + boolQuery.mustNot(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).gte(threshold)); + break; + case TN: + boolQuery.mustNot(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); + break; + case FN: + boolQuery.must(classInfo.matchingQuery()); + boolQuery.must(QueryBuilders.rangeQuery(classInfo.getProbabilityField()).lt(threshold)); + break; + default: + throw new IllegalArgumentException("Unknown enum 
value: " + condition); + } + return AggregationBuilders.filter(aggName(classInfo, threshold, condition), boolQuery); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java new file mode 100644 index 0000000000000..228dac00bfb68 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.java @@ -0,0 +1,350 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.bucket.filter.Filter; +import org.elasticsearch.search.aggregations.metrics.Percentiles; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import 
java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; +import java.util.stream.IntStream; + +/** + * Area under the curve (AUC) of the receiver operating characteristic (ROC). + * The ROC curve is a plot of the TPR (true positive rate) against + * the FPR (false positive rate) over a varying threshold. + * + * This particular implementation is making use of ES aggregations + * to calculate the curve. It then uses the trapezoidal rule to calculate + * the AUC. + * + * In particular, in order to calculate the ROC, we get percentiles of TP + * and FP against the predicted probability. We call those Rate-Threshold + * curves. We then scan ROC points from each Rate-Threshold curve against the + * other using interpolation. This gives us an approximation of the ROC curve + * that has the advantage of being efficient and resilient to some edge cases. + * + * When this is used for multi-class classification, it will calculate the ROC + * curve of each class versus the rest. + */ +public class AucRoc implements SoftClassificationMetric { + + public static final ParseField NAME = new ParseField("auc_roc"); + + public static final ParseField INCLUDE_CURVE = new ParseField("include_curve"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), + a -> new AucRoc((Boolean) a[0])); + + static { + PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE); + } + + private static final String PERCENTILES = "percentiles"; + + public static AucRoc fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final boolean includeCurve; + + public AucRoc(Boolean includeCurve) { + this.includeCurve = includeCurve == null ? 
false : includeCurve; + } + + public AucRoc(StreamInput in) throws IOException { + this.includeCurve = in.readBoolean(); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(includeCurve); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INCLUDE_CURVE.getPreferredName(), includeCurve); + builder.endObject(); + return builder; + } + + @Override + public String getMetricName() { + return NAME.getPreferredName(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AucRoc that = (AucRoc) o; + return Objects.equals(includeCurve, that.includeCurve); + } + + @Override + public int hashCode() { + return Objects.hash(includeCurve); + } + + @Override + public List aggs(String actualField, List classInfos) { + double[] percentiles = IntStream.range(1, 100).mapToDouble(v -> (double) v).toArray(); + List aggs = new ArrayList<>(); + for (ClassInfo classInfo : classInfos) { + AggregationBuilder percentilesForClassValueAgg = AggregationBuilders + .filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery()) + .subAggregation( + AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); + AggregationBuilder percentilesForRestAgg = AggregationBuilders + .filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery())) + .subAggregation( + AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(percentiles)); + aggs.add(percentilesForClassValueAgg); + aggs.add(percentilesForRestAgg); + } + return aggs; + } + + private String evaluatedLabelAggName(ClassInfo classInfo) { + return getMetricName() + "_" + classInfo.getName(); + 
} + + private String restLabelsAggName(ClassInfo classInfo) { + return getMetricName() + "_non_" + classInfo.getName(); + } + + @Override + public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) { + Filter classAgg = aggs.get(evaluatedLabelAggName(classInfo)); + Filter restAgg = aggs.get(restLabelsAggName(classInfo)); + double[] tpPercentiles = percentilesArray(classAgg.getAggregations().get(PERCENTILES), + "[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]"); + double[] fpPercentiles = percentilesArray(restAgg.getAggregations().get(PERCENTILES), + "[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]"); + List aucRocCurve = buildAucRocCurve(tpPercentiles, fpPercentiles); + double aucRocScore = calculateAucScore(aucRocCurve); + return new Result(aucRocScore, includeCurve ? aucRocCurve : Collections.emptyList()); + } + + private static double[] percentilesArray(Percentiles percentiles, String errorIfUndefined) { + double[] result = new double[99]; + percentiles.forEach(percentile -> { + if (Double.isNaN(percentile.getValue())) { + throw ExceptionsHelper.badRequestException(errorIfUndefined); + } + result[((int) percentile.getPercent()) - 1] = percentile.getValue(); + }); + return result; + } + + /** + * Visible for testing + */ + static List buildAucRocCurve(double[] tpPercentiles, double[] fpPercentiles) { + assert tpPercentiles.length == fpPercentiles.length; + assert tpPercentiles.length == 99; + + List aucRocCurve = new ArrayList<>(); + aucRocCurve.add(new AucRocPoint(0.0, 0.0, 1.0)); + aucRocCurve.add(new AucRocPoint(1.0, 1.0, 0.0)); + RateThresholdCurve tpCurve = new RateThresholdCurve(tpPercentiles, true); + RateThresholdCurve fpCurve = new RateThresholdCurve(fpPercentiles, false); + aucRocCurve.addAll(tpCurve.scanPoints(fpCurve)); + aucRocCurve.addAll(fpCurve.scanPoints(tpCurve)); + 
Collections.sort(aucRocCurve); + return aucRocCurve; + } + + /** + * Visible for testing + */ + static double calculateAucScore(List rocCurve) { + // Calculates AUC based on the trapezoid rule + double aucRoc = 0.0; + for (int i = 1; i < rocCurve.size(); i++) { + AucRocPoint left = rocCurve.get(i - 1); + AucRocPoint right = rocCurve.get(i); + aucRoc += (right.fpr - left.fpr) * (right.tpr + left.tpr) / 2; + } + return aucRoc; + } + + private static class RateThresholdCurve { + + private final double[] percentiles; + private final boolean isTp; + + private RateThresholdCurve(double[] percentiles, boolean isTp) { + this.percentiles = percentiles; + this.isTp = isTp; + } + + private double getRate(int index) { + return 1 - 0.01 * (index + 1); + } + + private double getThreshold(int index) { + return percentiles[index]; + } + + private double interpolateRate(double threshold) { + int binarySearchResult = Arrays.binarySearch(percentiles, threshold); + if (binarySearchResult >= 0) { + return getRate(binarySearchResult); + } else { + int right = (binarySearchResult * -1) -1; + int left = right - 1; + if (right >= percentiles.length) { + return 0.0; + } else if (left < 0) { + return 1.0; + } else { + double rightRate = getRate(right); + double leftRate = getRate(left); + return interpolate(threshold, percentiles[left], leftRate, percentiles[right], rightRate); + } + } + } + + private List scanPoints(RateThresholdCurve againstCurve) { + List points = new ArrayList<>(); + for (int index = 0; index < percentiles.length; index++) { + double rate = getRate(index); + double scannedThreshold = getThreshold(index); + double againstRate = againstCurve.interpolateRate(scannedThreshold); + AucRocPoint point; + if (isTp) { + point = new AucRocPoint(rate, againstRate, scannedThreshold); + } else { + point = new AucRocPoint(againstRate, rate, scannedThreshold); + } + points.add(point); + } + return points; + } + } + + public static final class AucRocPoint implements Comparable, 
ToXContentObject, Writeable { + double tpr; + double fpr; + double threshold; + + private AucRocPoint(double tpr, double fpr, double threshold) { + this.tpr = tpr; + this.fpr = fpr; + this.threshold = threshold; + } + + private AucRocPoint(StreamInput in) throws IOException { + this.tpr = in.readDouble(); + this.fpr = in.readDouble(); + this.threshold = in.readDouble(); + } + + @Override + public int compareTo(AucRocPoint o) { + return Comparator.comparingDouble((AucRocPoint p) -> p.threshold).reversed() + .thenComparing(p -> p.fpr) + .thenComparing(p -> p.tpr) + .compare(this, o); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(tpr); + out.writeDouble(fpr); + out.writeDouble(threshold); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("tpr", tpr); + builder.field("fpr", fpr); + builder.field("threshold", threshold); + builder.endObject(); + return builder; + } + + @Override + public String toString() { + return Strings.toString(this); + } + } + + private static double interpolate(double x, double x1, double y1, double x2, double y2) { + return y1 + (x - x1) * (y2 - y1) / (x2 - x1); + } + + public static class Result implements EvaluationMetricResult { + + private final double score; + private final List curve; + + public Result(double score, List curve) { + this.score = score; + this.curve = Objects.requireNonNull(curve); + } + + public Result(StreamInput in) throws IOException { + this.score = in.readDouble(); + this.curve = in.readList(AucRocPoint::new); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(score); + out.writeList(curve); + } + + @Override + public XContentBuilder 
toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("score", score); + if (curve.isEmpty() == false) { + builder.field("curve", curve); + } + builder.endObject(); + return builder; + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java new file mode 100644 index 0000000000000..f594e7598fc20 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/BinarySoftClassification.java @@ -0,0 +1,212 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.Evaluation; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** + * Evaluation of binary soft classification methods, e.g. outlier detection. + * This is useful to evaluate problems where a model outputs a probability of whether + * a data frame row belongs to one of two groups. 
+ */ +public class BinarySoftClassification implements Evaluation { + + public static final ParseField NAME = new ParseField("binary_soft_classification"); + + private static final ParseField ACTUAL_FIELD = new ParseField("actual_field"); + private static final ParseField PREDICTED_PROBABILITY_FIELD = new ParseField("predicted_probability_field"); + private static final ParseField METRICS = new ParseField("metrics"); + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME.getPreferredName(), a -> new BinarySoftClassification((String) a[0], (String) a[1], (List) a[2])); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_PROBABILITY_FIELD); + PARSER.declareNamedObjects(ConstructingObjectParser.optionalConstructorArg(), + (p, c, n) -> p.namedObject(SoftClassificationMetric.class, n, null), METRICS); + } + + public static BinarySoftClassification fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + /** + * The field where the actual class is marked up. + * The value of this field is assumed to either be 1 or 0, or true or false. + */ + private final String actualField; + + /** + * The field of the predicted probability in [0.0, 1.0]. + */ + private final String predictedProbabilityField; + + /** + * The list of metrics to calculate + */ + private final List metrics; + + public BinarySoftClassification(String actualField, String predictedProbabilityField, + @Nullable List metrics) { + this.actualField = ExceptionsHelper.requireNonNull(actualField, ACTUAL_FIELD); + this.predictedProbabilityField = ExceptionsHelper.requireNonNull(predictedProbabilityField, PREDICTED_PROBABILITY_FIELD); + this.metrics = initMetrics(metrics); + } + + private static List initMetrics(@Nullable List parsedMetrics) { + List metrics = parsedMetrics == null ? 
defaultMetrics() : parsedMetrics; + if (metrics.isEmpty()) { + throw ExceptionsHelper.badRequestException("[{}] must have one or more metrics", NAME.getPreferredName()); + } + Collections.sort(metrics, Comparator.comparing(SoftClassificationMetric::getMetricName)); + return metrics; + } + + private static List defaultMetrics() { + List defaultMetrics = new ArrayList<>(4); + defaultMetrics.add(new AucRoc(false)); + defaultMetrics.add(new Precision(Arrays.asList(0.25, 0.5, 0.75))); + defaultMetrics.add(new Recall(Arrays.asList(0.25, 0.5, 0.75))); + defaultMetrics.add(new ConfusionMatrix(Arrays.asList(0.25, 0.5, 0.75))); + return defaultMetrics; + } + + public BinarySoftClassification(StreamInput in) throws IOException { + this.actualField = in.readString(); + this.predictedProbabilityField = in.readString(); + this.metrics = in.readNamedWriteableList(SoftClassificationMetric.class); + } + + @Override + public String getWriteableName() { + return NAME.getPreferredName(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(actualField); + out.writeString(predictedProbabilityField); + out.writeNamedWriteableList(metrics); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACTUAL_FIELD.getPreferredName(), actualField); + builder.field(PREDICTED_PROBABILITY_FIELD.getPreferredName(), predictedProbabilityField); + + builder.startObject(METRICS.getPreferredName()); + for (SoftClassificationMetric metric : metrics) { + builder.field(metric.getMetricName(), metric); + } + builder.endObject(); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BinarySoftClassification that = (BinarySoftClassification) o; + return Objects.equals(actualField, that.actualField) + && 
Objects.equals(predictedProbabilityField, that.predictedProbabilityField) + && Objects.equals(metrics, that.metrics); + } + + @Override + public int hashCode() { + return Objects.hash(actualField, predictedProbabilityField, metrics); + } + + @Override + public String getName() { + return NAME.getPreferredName(); + } + + @Override + public SearchSourceBuilder buildSearch() { + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.size(0); + searchSourceBuilder.query(buildQuery()); + for (SoftClassificationMetric metric : metrics) { + List aggs = metric.aggs(actualField, Collections.singletonList(new BinaryClassInfo())); + aggs.forEach(searchSourceBuilder::aggregation); + } + return searchSourceBuilder; + } + + private QueryBuilder buildQuery() { + BoolQueryBuilder boolQuery = QueryBuilders.boolQuery(); + boolQuery.filter(QueryBuilders.existsQuery(actualField)); + boolQuery.filter(QueryBuilders.existsQuery(predictedProbabilityField)); + return boolQuery; + } + + @Override + public void evaluate(SearchResponse searchResponse, ActionListener> listener) { + if (searchResponse.getHits().getTotalHits().value == 0) { + listener.onFailure(ExceptionsHelper.badRequestException("No documents found containing both [{}, {}] fields", actualField, + predictedProbabilityField)); + return; + } + + List results = new ArrayList<>(); + Aggregations aggs = searchResponse.getAggregations(); + BinaryClassInfo binaryClassInfo = new BinaryClassInfo(); + for (SoftClassificationMetric metric : metrics) { + results.add(metric.evaluate(binaryClassInfo, aggs)); + } + listener.onResponse(results); + } + + private class BinaryClassInfo implements SoftClassificationMetric.ClassInfo { + + private QueryBuilder matchingQuery = QueryBuilders.queryStringQuery(actualField + ": (1 OR true)"); + + @Override + public String getName() { + return String.valueOf(true); + } + + @Override + public QueryBuilder matchingQuery() { + return matchingQuery; + } + + @Override + 
public String getProbabilityField() { + return predictedProbabilityField; + } + } +}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java
new file mode 100644
index 0000000000000..54f245962d515
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ConfusionMatrix.java
@@ -0,0 +1,187 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.filter.Filter;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Metric reporting, for each configured threshold, the number of true positives, false positives,
+ * true negatives and false negatives.
+ */
+public class ConfusionMatrix extends AbstractConfusionMatrixMetric {
+
+    public static final ParseField NAME = new ParseField("confusion_matrix");
+
+    // The parser hands constructor arguments over as plain Objects; a[0] is the double array declared under AT.
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<ConfusionMatrix, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
+        a -> new ConfusionMatrix((List<Double>) a[0]));
+
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT);
+    }
+
+    /**
+     * Parses a {@code confusion_matrix} metric from the given parser.
+     */
+    public static ConfusionMatrix fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * @param at the thresholds at which the confusion matrix is computed
+     */
+    public ConfusionMatrix(List<Double> at) {
+        super(at.stream().mapToDouble(Double::doubleValue).toArray());
+    }
+
+    public ConfusionMatrix(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getMetricName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        ConfusionMatrix that = (ConfusionMatrix) o;
+        return Arrays.equals(thresholds, that.thresholds);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(thresholds);
+    }
+
+    /**
+     * Requests the TP, FP, TN and FN aggregations of every class at the given threshold.
+     */
+    @Override
+    protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
+        List<AggregationBuilder> aggs = new ArrayList<>();
+        for (ClassInfo classInfo : classInfos) {
+            aggs.add(buildAgg(classInfo, threshold, Condition.TP));
+            aggs.add(buildAgg(classInfo, threshold, Condition.FP));
+            aggs.add(buildAgg(classInfo, threshold, Condition.TN));
+            aggs.add(buildAgg(classInfo, threshold, Condition.FN));
+        }
+        return aggs;
+    }
+
+    /**
+     * Reads the document count of the TP, FP, TN and FN aggregations at each threshold.
+     */
+    @Override
+    public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
+        long[] tp = new long[thresholds.length];
+        long[] fp = new long[thresholds.length];
+        long[] tn = new long[thresholds.length];
+        long[] fn = new long[thresholds.length];
+        for (int i = 0; i < thresholds.length; i++) {
+            Filter tpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TP));
+            Filter fpAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FP));
+            Filter tnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.TN));
+            Filter fnAgg = aggs.get(aggName(classInfo, thresholds[i], Condition.FN));
+            tp[i] = tpAgg.getDocCount();
+            fp[i] = fpAgg.getDocCount();
+            tn[i] = tnAgg.getDocCount();
+            fn[i] = fnAgg.getDocCount();
+        }
+        return new Result(thresholds, tp, fp, tn, fn);
+    }
+
+    /**
+     * The confusion matrix counts; element {@code i} of each count array belongs to {@code thresholds[i]}.
+     */
+    public static class Result implements EvaluationMetricResult {
+
+        private final double[] thresholds;
+        private final long[] tp;
+        private final long[] fp;
+        private final long[] tn;
+        private final long[] fn;
+
+        public Result(double[] thresholds, long[] tp, long[] fp, long[] tn, long[] fn) {
+            assert thresholds.length == tp.length;
+            assert thresholds.length == fp.length;
+            assert thresholds.length == tn.length;
+            assert thresholds.length == fn.length;
+            this.thresholds = thresholds;
+            this.tp = tp;
+            this.fp = fp;
+            this.tn = tn;
+            this.fn = fn;
+        }
+
+        public Result(StreamInput in) throws IOException {
+            this.thresholds = in.readDoubleArray();
+            this.tp = in.readLongArray();
+            this.fp = in.readLongArray();
+            this.tn = in.readLongArray();
+            this.fn = in.readLongArray();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public String getName() {
+            return NAME.getPreferredName();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeDoubleArray(thresholds);
+            out.writeLongArray(tp);
+            out.writeLongArray(fp);
+            out.writeLongArray(tn);
+            out.writeLongArray(fn);
+        }
+
+        /**
+         * Renders one object per threshold, keyed by the threshold value, holding the four counts.
+         */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            for (int i = 0; i < thresholds.length; i++) {
+                builder.startObject(String.valueOf(thresholds[i]));
+                builder.field("tp", tp[i]);
+                builder.field("fp", fp[i]);
+                builder.field("tn", tn[i]);
+                builder.field("fn", fn[i]);
+                builder.endObject();
+            }
+            builder.endObject();
+            return builder;
+        }
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java new file mode
100644
index 0000000000000..d38a52bb203e8
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Precision.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.filter.Filter;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Metric computing the precision, {@code tp / (tp + fp)}, at each configured threshold.
+ * A threshold with no true and no false positives is reported as {@code 0.0}.
+ */
+public class Precision extends AbstractConfusionMatrixMetric {
+
+    public static final ParseField NAME = new ParseField("precision");
+
+    // The parser hands constructor arguments over as plain Objects; a[0] is the double array declared under AT.
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<Precision, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
+        a -> new Precision((List<Double>) a[0]));
+
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT);
+    }
+
+    /**
+     * Parses a {@code precision} metric from the given parser.
+     */
+    public static Precision fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * @param at the thresholds at which precision is computed
+     */
+    public Precision(List<Double> at) {
+        super(at.stream().mapToDouble(Double::doubleValue).toArray());
+    }
+
+    public Precision(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getMetricName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Precision that = (Precision) o;
+        return Arrays.equals(thresholds, that.thresholds);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(thresholds);
+    }
+
+    /**
+     * Requests the TP and FP aggregations of every class at the given threshold.
+     */
+    @Override
+    protected List<AggregationBuilder> aggsAt(String labelField, List<ClassInfo> classInfos, double threshold) {
+        List<AggregationBuilder> aggs = new ArrayList<>();
+        for (ClassInfo classInfo : classInfos) {
+            aggs.add(buildAgg(classInfo, threshold, Condition.TP));
+            aggs.add(buildAgg(classInfo, threshold, Condition.FP));
+        }
+        return aggs;
+    }
+
+    /**
+     * Computes the precision at each threshold from the document counts of the TP and FP aggregations.
+     */
+    @Override
+    public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
+        double[] precisions = new double[thresholds.length];
+        for (int i = 0; i < precisions.length; i++) {
+            double threshold = thresholds[i];
+            Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
+            Filter fpAgg = aggs.get(aggName(classInfo, threshold, Condition.FP));
+            long tp = tpAgg.getDocCount();
+            long fp = fpAgg.getDocCount();
+            precisions[i] = tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);
+        }
+        return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, precisions);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java
new file mode 100644
index 0000000000000..5c4ab57241d95
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/Recall.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.search.aggregations.bucket.filter.Filter;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Metric computing the recall, {@code tp / (tp + fn)}, at each configured threshold.
+ * A threshold with no true positives and no false negatives is reported as {@code 0.0}.
+ */
+public class Recall extends AbstractConfusionMatrixMetric {
+
+    public static final ParseField NAME = new ParseField("recall");
+
+    // The parser hands constructor arguments over as plain Objects; a[0] is the double array declared under AT.
+    @SuppressWarnings("unchecked")
+    private static final ConstructingObjectParser<Recall, Void> PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(),
+        a -> new Recall((List<Double>) a[0]));
+
+    static {
+        PARSER.declareDoubleArray(ConstructingObjectParser.constructorArg(), AT);
+    }
+
+    /**
+     * Parses a {@code recall} metric from the given parser.
+     */
+    public static Recall fromXContent(XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    /**
+     * @param at the thresholds at which recall is computed
+     */
+    public Recall(List<Double> at) {
+        super(at.stream().mapToDouble(Double::doubleValue).toArray());
+    }
+
+    public Recall(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public String getMetricName() {
+        return NAME.getPreferredName();
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        Recall that = (Recall) o;
+        return Arrays.equals(thresholds, that.thresholds);
+    }
+
+    @Override
+    public int hashCode() {
+        return Arrays.hashCode(thresholds);
+    }
+
+    /**
+     * Requests the TP and FN aggregations of every class at the given threshold.
+     */
+    @Override
+    protected List<AggregationBuilder> aggsAt(String actualField, List<ClassInfo> classInfos, double threshold) {
+        List<AggregationBuilder> aggs = new ArrayList<>();
+        for (ClassInfo classInfo : classInfos) {
+            aggs.add(buildAgg(classInfo, threshold, Condition.TP));
+            aggs.add(buildAgg(classInfo, threshold, Condition.FN));
+        }
+        return aggs;
+    }
+
+    /**
+     * Computes the recall at each threshold from the document counts of the TP and FN aggregations.
+     */
+    @Override
+    public EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs) {
+        double[] recalls = new double[thresholds.length];
+        for (int i = 0; i < recalls.length; i++) {
+            double threshold = thresholds[i];
+            Filter tpAgg = aggs.get(aggName(classInfo, threshold, Condition.TP));
+            Filter fnAgg = aggs.get(aggName(classInfo, threshold, Condition.FN));
+            long tp = tpAgg.getDocCount();
+            long fn = fnAgg.getDocCount();
+            recalls[i] = tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
+        }
+        return new ScoreByThresholdResult(NAME.getPreferredName(), thresholds, recalls);
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java
new file mode 100644
index 0000000000000..bd6b6e7db25a1
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/ScoreByThresholdResult.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.io.IOException;
+import java.util.Objects;
+
+/**
+ * Result of a metric that yields one score per threshold, e.g. precision or recall.
+ * The writeable name is shared by all such results while {@link #getName()} returns the metric's own name.
+ */
+public class ScoreByThresholdResult implements EvaluationMetricResult {
+
+    public static final String NAME = "score_by_threshold_result";
+
+    private final String name;
+    private final double[] thresholds;
+    private final double[] scores;
+
+    /**
+     * @param name the name of the metric the scores belong to
+     * @param thresholds the thresholds the metric was computed at
+     * @param scores the scores; {@code scores[i]} corresponds to {@code thresholds[i]}
+     */
+    public ScoreByThresholdResult(String name, double[] thresholds, double[] scores) {
+        assert thresholds.length == scores.length;
+        this.name = Objects.requireNonNull(name);
+        this.thresholds = thresholds;
+        this.scores = scores;
+    }
+
+    public ScoreByThresholdResult(StreamInput in) throws IOException {
+        this.name = in.readString();
+        this.thresholds = in.readDoubleArray();
+        this.scores = in.readDoubleArray();
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    @Override
+    public String getName() {
+        return name;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(name);
+        out.writeDoubleArray(thresholds);
+        out.writeDoubleArray(scores);
+    }
+
+    /**
+     * Renders an object with one field per threshold: the threshold is the key and its score the value.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        for (int i = 0; i < thresholds.length; i++) {
+            builder.field(String.valueOf(thresholds[i]), scores[i]);
+        }
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java new file mode 100644 index
0000000000000..dfb256e9b52f2
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/SoftClassificationMetric.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;
+
+import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.Aggregations;
+import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
+
+import java.util.List;
+
+/**
+ * A metric of a soft classification evaluation. Implementations declare the aggregations
+ * they need and turn the aggregation results into an {@link EvaluationMetricResult}.
+ */
+public interface SoftClassificationMetric extends ToXContentObject, NamedWriteable {
+
+    /**
+     * The information of a specific class
+     */
+    interface ClassInfo {
+
+        /**
+         * Returns the class name
+         */
+        String getName();
+
+        /**
+         * Returns a query that matches documents of the class
+         */
+        QueryBuilder matchingQuery();
+
+        /**
+         * Returns the field that holds the probability of a document belonging to the class
+         */
+        String getProbabilityField();
+    }
+
+    /**
+     * Returns the name of the metric (which may differ from the writeable name)
+     */
+    String getMetricName();
+
+    /**
+     * Builds the aggregations that collect the data required to compute the metric
+     * @param actualField the field that stores the actual class
+     * @param classInfos the information of each class to compute the metric for
+     * @return the aggregations required to compute the metric
+     */
+    List<AggregationBuilder> aggs(String actualField, List<ClassInfo> classInfos);
+
+    /**
+     * Calculates the metric result for a given class
+     * @param classInfo the class to calculate the metric for
+     * @param aggs the aggregations
+     * @return the metric result
+     */
+    EvaluationMetricResult evaluate(ClassInfo classInfo, Aggregations aggs);
+}
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 22eb0dc357bed..417184f8a752b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -50,6 +50,10 @@ public final class Messages { "Datafeed frequency [{0}] must be a multiple of the aggregation interval [{1}]"; public static final String DATAFEED_ID_ALREADY_TAKEN = "A datafeed with id [{0}] already exists"; + public static final String DATA_FRAME_ANALYTICS_BAD_QUERY_FORMAT = "Data Frame Analytics config query is not parsable"; + public static final String DATA_FRAME_ANALYTICS_BAD_FIELD_FILTER = + "No compatible fields could be detected in index [{0}] with name [{1}]"; + public static final String FILTER_CANNOT_DELETE = "Cannot delete filter [{0}] currently used by jobs {1}"; public static final String FILTER_CONTAINS_TOO_MANY_ITEMS = "Filter [{0}] contains too many items; up to [{1}] items are allowed"; public static final String FILTER_NOT_FOUND = "No filter with id [{0}] exists"; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index e3f9c3c4381f7..327950c29954b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -26,6 +26,10 @@ import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -144,6 +148,7 @@ public static XContentBuilder configMapping() throws IOException { addJobConfigFields(builder); addDatafeedConfigFields(builder); + addDataFrameAnalyticsFields(builder); builder.endObject() .endObject() @@ -386,6 +391,52 @@ public static void addDatafeedConfigFields(XContentBuilder builder) throws IOExc .endObject(); } + public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws IOException { + builder.startObject(DataFrameAnalyticsConfig.ID.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsConfig.SOURCE.getPreferredName()) + .startObject(PROPERTIES) + .startObject(DataFrameAnalyticsSource.INDEX.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsSource.QUERY.getPreferredName()) + .field(ENABLED, false) + .endObject() + .endObject() + .endObject() + .startObject(DataFrameAnalyticsConfig.DEST.getPreferredName()) + .startObject(PROPERTIES) + .startObject(DataFrameAnalyticsDest.INDEX.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .endObject() + .endObject() + .startObject(DataFrameAnalyticsConfig.ANALYZED_FIELDS.getPreferredName()) + .field(ENABLED, false) + .endObject() + .startObject(DataFrameAnalyticsConfig.ANALYSIS.getPreferredName()) + 
.startObject(PROPERTIES) + .startObject(OutlierDetection.NAME.getPreferredName()) + .startObject(PROPERTIES) + .startObject(OutlierDetection.N_NEIGHBORS.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() + .startObject(OutlierDetection.METHOD.getPreferredName()) + .field(TYPE, KEYWORD) + .endObject() + .startObject(OutlierDetection.MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName()) + .field(TYPE, DOUBLE) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + } + /** * Creates a default mapping which has a dynamic template that * treats all dynamically added fields as keywords. This is needed diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index 58785235db1db..5255fe5e8791a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -9,6 +9,10 @@ import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig; import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; import org.elasticsearch.xpack.core.ml.job.config.DataDescription; @@ -268,6 +272,20 @@ public final class ReservedFieldNames { ChunkingConfig.MODE_FIELD.getPreferredName(), ChunkingConfig.TIME_SPAN_FIELD.getPreferredName(), + 
DataFrameAnalyticsConfig.ID.getPreferredName(), + DataFrameAnalyticsConfig.SOURCE.getPreferredName(), + DataFrameAnalyticsConfig.DEST.getPreferredName(), + DataFrameAnalyticsConfig.ANALYSIS.getPreferredName(), + DataFrameAnalyticsConfig.ANALYZED_FIELDS.getPreferredName(), + DataFrameAnalyticsDest.INDEX.getPreferredName(), + DataFrameAnalyticsDest.RESULTS_FIELD.getPreferredName(), + DataFrameAnalyticsSource.INDEX.getPreferredName(), + DataFrameAnalyticsSource.QUERY.getPreferredName(), + OutlierDetection.NAME.getPreferredName(), + OutlierDetection.N_NEIGHBORS.getPreferredName(), + OutlierDetection.METHOD.getPreferredName(), + OutlierDetection.MINIMUM_SCORE_TO_WRITE_FEATURE_INFLUENCE.getPreferredName(), + ElasticsearchMappings.CONFIG_TYPE, GetResult._ID, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java index b66fd948a5a83..2d4c636172eca 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/process/writer/RecordWriter.java @@ -10,7 +10,7 @@ /** * Interface for classes that write arrays of strings to the - * Ml analytics processes. + * Ml data frame analytics processes. 
*/ public interface RecordWriter { /** diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index 47c0d4f64f96f..320eace983590 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -10,6 +10,7 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.common.ParseField; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.xpack.core.ml.job.messages.Messages; @@ -34,6 +35,14 @@ public static ResourceAlreadyExistsException datafeedAlreadyExists(String datafe return new ResourceAlreadyExistsException(Messages.getMessage(Messages.DATAFEED_ID_ALREADY_TAKEN, datafeedId)); } + public static ResourceNotFoundException missingDataFrameAnalytics(String id) { + return new ResourceNotFoundException("No known data frame analytics with id [{}]", id); + } + + public static ResourceAlreadyExistsException dataFrameAnalyticsAlreadyExists(String id) { + return new ResourceAlreadyExistsException("A data frame analytics with id [{}] already exists", id); + } + public static ElasticsearchException serverError(String msg) { return new ElasticsearchException(msg); } @@ -86,4 +95,8 @@ public static T requireNonNull(T obj, String paramName) { } return obj; } + + public static T requireNonNull(T obj, ParseField paramName) { + return requireNonNull(obj, paramName.getPreferredName()); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java similarity 
index 81% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java index 755c5a3526d01..3fe0ba70331a0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/QueryProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/QueryProvider.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.core.ml.datafeed; +package org.elasticsearch.xpack.core.ml.utils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -17,9 +17,6 @@ import org.elasticsearch.index.query.MatchAllQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.xpack.core.ml.job.messages.Messages; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; import java.util.Collections; @@ -27,22 +24,22 @@ import java.util.Map; import java.util.Objects; -class QueryProvider implements Writeable, ToXContentObject { +public class QueryProvider implements Writeable, ToXContentObject { - private static final Logger logger = LogManager.getLogger(AggProvider.class); + private static final Logger logger = LogManager.getLogger(QueryProvider.class); private Exception parsingException; private QueryBuilder parsedQuery; private Map query; - static QueryProvider defaultQuery() { + public static QueryProvider defaultQuery() { return new QueryProvider( Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()), QueryBuilders.matchAllQuery(), null); } - static QueryProvider fromXContent(XContentParser parser, boolean 
lenient) throws IOException { + public static QueryProvider fromXContent(XContentParser parser, boolean lenient, String failureMessage) throws IOException { Map query = parser.mapOrdered(); QueryBuilder parsedQuery = null; Exception exception = null; @@ -54,15 +51,15 @@ static QueryProvider fromXContent(XContentParser parser, boolean lenient) throws } exception = ex; if (lenient) { - logger.warn(Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT, ex); + logger.warn(failureMessage, ex); } else { - throw ExceptionsHelper.badRequestException(Messages.DATAFEED_CONFIG_QUERY_BAD_FORMAT, ex); + throw ExceptionsHelper.badRequestException(failureMessage, ex); } } return new QueryProvider(query, parsedQuery, exception); } - static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOException { + public static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOException { return parsedQuery == null ? null : new QueryProvider( @@ -71,7 +68,7 @@ static QueryProvider fromParsedQuery(QueryBuilder parsedQuery) throws IOExceptio null); } - static QueryProvider fromStream(StreamInput in) throws IOException { + public static QueryProvider fromStream(StreamInput in) throws IOException { return new QueryProvider(in.readMap(), in.readOptionalNamedWriteable(QueryBuilder.class), in.readException()); } @@ -81,7 +78,7 @@ static QueryProvider fromStream(StreamInput in) throws IOException { this.parsingException = parsingException; } - QueryProvider(QueryProvider other) { + public QueryProvider(QueryProvider other) { this(other.query, other.parsedQuery, other.parsingException); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java index c829914c78fc0..494b1c76c5de4 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractSerializingDataFrameTestCase.java @@ -13,6 +13,10 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.junit.Before; import java.util.List; @@ -30,7 +34,11 @@ public void registerNamedObjects() { SearchModule searchModule = new SearchModule(Settings.EMPTY, emptyList()); List namedWriteables = searchModule.getNamedWriteables(); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java index c8755feed073a..85a24da5a1a24 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/AbstractWireSerializingDataFrameTestCase.java @@ -12,6 +12,10 @@ import 
org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; +import org.elasticsearch.xpack.core.dataframe.transforms.SyncConfig; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; import org.junit.Before; import java.util.List; @@ -30,7 +34,11 @@ public void registerNamedObjects() { SearchModule searchModule = new SearchModule(Settings.EMPTY, emptyList()); List namedWriteables = searchModule.getNamedWriteables(); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); + List namedXContents = searchModule.getNamedXContents(); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java index c3a921a90d26b..ea6f2a47f4692 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/action/PreviewDataFrameTransformActionRequestTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.dataframe.action.PreviewDataFrameTransformAction.Request; import org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfig; +import 
org.elasticsearch.xpack.core.dataframe.transforms.DataFrameTransformConfigTests; import org.elasticsearch.xpack.core.dataframe.transforms.DestConfig; import org.elasticsearch.xpack.core.dataframe.transforms.pivot.PivotConfigTests; @@ -39,9 +40,14 @@ protected boolean supportsUnknownFields() { @Override protected Request createTestInstance() { - DataFrameTransformConfig config = new DataFrameTransformConfig("transform-preview", randomSourceConfig(), + DataFrameTransformConfig config = new DataFrameTransformConfig( + "transform-preview", + randomSourceConfig(), new DestConfig("unused-transform-preview-index", null), - null, PivotConfigTests.randomPivotConfig(), null); + randomBoolean() ? DataFrameTransformConfigTests.randomSyncConfig() : null, + null, + PivotConfigTests.randomPivotConfig(), + null); return new Request(config); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java index 79a50c30a276e..7dd501160ba00 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/AbstractSerializingDataFrameTestCase.java @@ -19,6 +19,7 @@ import org.elasticsearch.search.aggregations.BaseAggregationBuilder; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.dataframe.DataFrameField; +import org.elasticsearch.xpack.core.dataframe.DataFrameNamedXContentProvider; import org.junit.Before; import java.util.Collections; @@ -48,12 +49,15 @@ public void registerAggregationNamedObjects() throws Exception { MockDeprecatedQueryBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(AggregationBuilder.class, MockDeprecatedAggregationBuilder.NAME, 
MockDeprecatedAggregationBuilder::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SyncConfig.class, DataFrameField.TIME_BASED_SYNC.getPreferredName(), + TimeSyncConfig::new)); List namedXContents = searchModule.getNamedXContents(); namedXContents.add(new NamedXContentRegistry.Entry(QueryBuilder.class, new ParseField(MockDeprecatedQueryBuilder.NAME), (p, c) -> MockDeprecatedQueryBuilder.fromXContent(p))); namedXContents.add(new NamedXContentRegistry.Entry(BaseAggregationBuilder.class, new ParseField(MockDeprecatedAggregationBuilder.NAME), (p, c) -> MockDeprecatedAggregationBuilder.fromXContent(p))); + namedXContents.addAll(new DataFrameNamedXContentProvider().getNamedXContentParsers()); namedWriteableRegistry = new NamedWriteableRegistry(namedWriteables); namedXContentRegistry = new NamedXContentRegistry(namedXContents); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java index 907c8eb98e69f..dd5b5c9ff8841 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/DataFrameTransformConfigTests.java @@ -46,6 +46,7 @@ public static DataFrameTransformConfig randomDataFrameTransformConfigWithoutHead return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), + randomBoolean() ? null : randomSyncConfig(), null, PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000), @@ -57,6 +58,7 @@ public static DataFrameTransformConfig randomDataFrameTransformConfig(String id) return new DataFrameTransformConfig(id, randomSourceConfig(), randomDestConfig(), + randomBoolean() ? 
null : randomSyncConfig(), randomHeaders(), PivotConfigTests.randomPivotConfig(), randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000), @@ -66,13 +68,17 @@ public static DataFrameTransformConfig randomDataFrameTransformConfig(String id) public static DataFrameTransformConfig randomInvalidDataFrameTransformConfig() { if (randomBoolean()) { - return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomInvalidSourceConfig(), - randomDestConfig(), randomHeaders(), PivotConfigTests.randomPivotConfig(), - randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100)); + return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomInvalidSourceConfig(), randomDestConfig(), + randomBoolean() ? randomSyncConfig() : null, randomHeaders(), PivotConfigTests.randomPivotConfig(), + randomBoolean() ? null : randomAlphaOfLengthBetween(1, 1000)); } // else - return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), - randomDestConfig(), randomHeaders(), PivotConfigTests.randomInvalidPivotConfig(), - randomBoolean() ? null : randomAlphaOfLengthBetween(1, 100)); + return new DataFrameTransformConfig(randomAlphaOfLengthBetween(1, 10), randomSourceConfig(), randomDestConfig(), + randomBoolean() ? randomSyncConfig() : null, randomHeaders(), PivotConfigTests.randomInvalidPivotConfig(), + randomBoolean() ? 
null : randomAlphaOfLengthBetween(1, 1000)); + } + + public static SyncConfig randomSyncConfig() { + return TimeSyncConfigTests.randomTimeSyncConfig(); } @Before @@ -223,11 +229,11 @@ public void testXContentForInternalStorage() throws IOException { public void testMaxLengthDescription() { IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> new DataFrameTransformConfig("id", - randomSourceConfig(), randomDestConfig(), null, PivotConfigTests.randomPivotConfig(), randomAlphaOfLength(1001))); + randomSourceConfig(), randomDestConfig(), null, null, PivotConfigTests.randomPivotConfig(), randomAlphaOfLength(1001))); assertThat(exception.getMessage(), equalTo("[description] must be less than 1000 characters in length.")); String description = randomAlphaOfLength(1000); DataFrameTransformConfig config = new DataFrameTransformConfig("id", - randomSourceConfig(), randomDestConfig(), null, PivotConfigTests.randomPivotConfig(), description); + randomSourceConfig(), randomDestConfig(), null, null, PivotConfigTests.randomPivotConfig(), description); assertThat(description, equalTo(config.getDescription())); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java new file mode 100644 index 0000000000000..763e13e77aee0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/dataframe/transforms/TimeSyncConfigTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.dataframe.transforms; + +import org.elasticsearch.common.io.stream.Writeable.Reader; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.dataframe.transforms.TimeSyncConfig; + +import java.io.IOException; + +public class TimeSyncConfigTests extends AbstractSerializingTestCase { + + public static TimeSyncConfig randomTimeSyncConfig() { + return new TimeSyncConfig(randomAlphaOfLengthBetween(1, 10), new TimeValue(randomNonNegativeLong())); + } + + @Override + protected TimeSyncConfig doParseInstance(XContentParser parser) throws IOException { + return TimeSyncConfig.fromXContent(parser, false); + } + + @Override + protected TimeSyncConfig createTestInstance() { + return randomTimeSyncConfig(); + } + + @Override + protected Reader instanceReader() { + return TimeSyncConfig::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java index 3afe76b8b171f..f2015b1a2bbb5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MlTasksTests.java @@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; public class MlTasksTests extends ESTestCase { public void testGetJobState() { @@ -161,4 +162,10 @@ public void testUnallocatedDatafeedIds() { assertThat(MlTasks.unallocatedDatafeedIds(tasksBuilder.build(), nodes), containsInAnyOrder("datafeed_without_assignment", "datafeed_without_node")); } + + public void testDataFrameAnalyticsTaskIds() { + String taskId = MlTasks.dataFrameAnalyticsTaskId("foo"); + assertThat(taskId, 
equalTo("data_frame_analytics-foo")); + assertThat(MlTasks.dataFrameAnalyticsIdFromTaskId(taskId), equalTo("foo")); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java new file mode 100644 index 0000000000000..e899b7e6642da --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/EvaluateDataFrameActionRequestTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractStreamableXContentTestCase; +import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction.Request; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.BinarySoftClassificationTests; + +import java.util.ArrayList; +import java.util.List; + +public class EvaluateDataFrameActionRequestTests extends AbstractStreamableXContentTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(new MlEvaluationNamedXContentProvider().getNamedWriteables()); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + return new NamedXContentRegistry(new MlEvaluationNamedXContentProvider().getNamedXContentParsers()); + } + + @Override + protected Request createTestInstance() { + Request request = 
new Request(); + int indicesCount = randomIntBetween(1, 5); + List indices = new ArrayList<>(indicesCount); + for (int i = 0; i < indicesCount; i++) { + indices.add(randomAlphaOfLength(10)); + } + request.setIndices(indices); + request.setEvaluation(BinarySoftClassificationTests.createRandom()); + return request; + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected Request createBlankInstance() { + return new Request(); + } + + @Override + protected Request doParseInstance(XContentParser parser) { + return Request.parseRequest(parser); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java new file mode 100644 index 0000000000000..8a7b6717abd92 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsActionResponseTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractStreamableTestCase; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Response; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class GetDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected Response createTestInstance() { + int listSize = randomInt(10); + List analytics = new ArrayList<>(listSize); + for (int j = 0; j < listSize; j++) { + analytics.add(DataFrameAnalyticsConfigTests.createRandom(DataFrameAnalyticsConfigTests.randomValidId())); + } + return new Response(new 
QueryPage<>(analytics, analytics.size(), Response.RESULTS_FIELD)); + } + + @Override + protected Response createBlankInstance() { + return new Response(); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..438474076c910 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsRequestTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction.Request; + +public class GetDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(); + request.setResourceId(randomAlphaOfLength(20)); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java new file mode 100644 index 0000000000000..e01618520f5a8 --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsActionResponseTests.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.QueryPage; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Response; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; + +import java.util.ArrayList; +import java.util.List; + +public class GetDataFrameAnalyticsStatsActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + int listSize = randomInt(10); + List analytics = new ArrayList<>(listSize); + for (int j = 0; j < listSize; j++) { + Integer progressPercentage = randomBoolean() ? 
null : randomIntBetween(0, 100); + Response.Stats stats = new Response.Stats(DataFrameAnalyticsConfigTests.randomValidId(), + randomFrom(DataFrameAnalyticsState.values()), progressPercentage, null, randomAlphaOfLength(20)); + analytics.add(stats); + } + return new Response(new QueryPage<>(analytics, analytics.size(), GetDataFrameAnalyticsAction.Response.RESULTS_FIELD)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java new file mode 100644 index 0000000000000..918d04873ef2c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsRequestTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; +import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction.Request; + +public class GetDataFrameAnalyticsStatsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100))); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java new file mode 100644 index 0000000000000..1e5416d5a5dce --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionRequestTests.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractStreamableXContentTestCase; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Request; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class PutDataFrameAnalyticsActionRequestTests extends AbstractStreamableXContentTestCase { + + private String id; + + @Before + public void setUpId() { + id = DataFrameAnalyticsConfigTests.randomValidId(); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected Request createTestInstance() { + return new Request(DataFrameAnalyticsConfigTests.createRandom(id)); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } + + @Override + protected 
Request createBlankInstance() { + return new Request(); + } + + @Override + protected Request doParseInstance(XContentParser parser) { + return Request.parseRequest(id, parser); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java new file mode 100644 index 0000000000000..d323505828e42 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutDataFrameAnalyticsActionResponseTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.test.AbstractStreamableTestCase; +import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction.Response; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class PutDataFrameAnalyticsActionResponseTests extends AbstractStreamableTestCase { + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, 
Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected Response createTestInstance() { + return new Response(DataFrameAnalyticsConfigTests.createRandom(DataFrameAnalyticsConfigTests.randomValidId())); + } + + @Override + protected Response createBlankInstance() { + return new Response(); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..a3db5833b820d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsRequestTests.java @@ -0,0 +1,28 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction.Request; + +public class StartDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + if (randomBoolean()) { + request.setTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java new file mode 100644 index 0000000000000..d06d8cb1a1860 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsActionResponseTests.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction.Response; + +public class StopDataFrameAnalyticsActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + protected Response createTestInstance() { + return new Response(randomBoolean()); + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java new file mode 100644 index 0000000000000..9c61164c5f02a --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StopDataFrameAnalyticsRequestTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.StopDataFrameAnalyticsAction.Request; + +import java.util.HashSet; +import java.util.Set; + +public class StopDataFrameAnalyticsRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + Request request = new Request(randomAlphaOfLength(20)); + if (randomBoolean()) { + request.setTimeout(TimeValue.timeValueMillis(randomNonNegativeLong())); + } + if (randomBoolean()) { + request.setAllowNoMatch(randomBoolean()); + } + int expandedIdsCount = randomIntBetween(0, 10); + Set expandedIds = new HashSet<>(); + for (int i = 0; i < expandedIdsCount; i++) { + expandedIds.add(randomAlphaOfLength(20)); + } + request.setExpandedIds(expandedIds); + return request; + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java index 9df5647a62ab9..15ba85a8edbf1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedConfigTests.java @@ -46,6 +46,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig.Mode; import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; @@ -57,7 +58,7 @@ import java.util.List; import java.util.Map; -import static 
org.elasticsearch.xpack.core.ml.datafeed.QueryProviderTests.createRandomValidQueryProvider; +import static org.elasticsearch.xpack.core.ml.utils.QueryProviderTests.createRandomValidQueryProvider; import static org.elasticsearch.xpack.core.ml.job.messages.Messages.DATAFEED_AGGREGATIONS_INTERVAL_MUST_BE_GREATER_THAN_ZERO; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java index 55ee9826a2d80..35bff34f93803 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java @@ -38,6 +38,7 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.datafeed.ChunkingConfig.Mode; import org.elasticsearch.xpack.core.ml.job.config.JobTests; +import org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.core.ml.utils.XContentObjectTransformer; import java.io.IOException; @@ -47,7 +48,7 @@ import java.util.List; import static org.elasticsearch.xpack.core.ml.datafeed.AggProviderTests.createRandomValidAggProvider; -import static org.elasticsearch.xpack.core.ml.datafeed.QueryProviderTests.createRandomValidQueryProvider; +import static org.elasticsearch.xpack.core.ml.utils.QueryProviderTests.createRandomValidQueryProvider; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java new file mode 100644 index 
0000000000000..dd9b229913aa9 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -0,0 +1,251 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import com.carrotsearch.randomizedtesting.generators.CodepointSetGenerator; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.fetch.subphase.FetchSourceContext; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.MlDataFrameAnalysisNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetectionTests; +import 
org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.startsWith; + +public class DataFrameAnalyticsConfigTests extends AbstractSerializingTestCase { + + @Override + protected DataFrameAnalyticsConfig doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build(); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List namedWriteables = new ArrayList<>(); + namedWriteables.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedWriteables()); + namedWriteables.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedWriteables()); + return new NamedWriteableRegistry(namedWriteables); + } + + @Override + protected NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + } + + @Override + protected DataFrameAnalyticsConfig createTestInstance() { + return createRandom(randomValidId()); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataFrameAnalyticsConfig::new; + } + + public static DataFrameAnalyticsConfig createRandom(String id) { + return createRandomBuilder(id).build(); + } + + public static DataFrameAnalyticsConfig.Builder createRandomBuilder(String id) { + DataFrameAnalyticsSource source = 
DataFrameAnalyticsSourceTests.createRandom(); + DataFrameAnalyticsDest dest = DataFrameAnalyticsDestTests.createRandom(); + DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder() + .setId(id) + .setAnalysis(OutlierDetectionTests.createRandom()) + .setSource(source) + .setDest(dest); + if (randomBoolean()) { + builder.setAnalyzedFields(new FetchSourceContext(true, + generateRandomStringArray(10, 10, false, false), + generateRandomStringArray(10, 10, false, false))); + } + if (randomBoolean()) { + builder.setModelMemoryLimit(new ByteSizeValue(randomIntBetween(1, 16), randomFrom(ByteSizeUnit.MB, ByteSizeUnit.GB))); + } + return builder; + } + + public static String randomValidId() { + CodepointSetGenerator generator = new CodepointSetGenerator("abcdefghijklmnopqrstuvwxyz".toCharArray()); + return generator.ofCodePointsLength(random(), 10, 10); + } + + private static final String ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS = "{\n" + + " \"id\": \"old-data-frame\",\n" + + //query:match:type stopped being supported in 6.x + " \"source\": {\"index\":\"my-index\", \"query\": {\"match\" : {\"query\":\"fieldName\", \"type\": \"phrase\"}}},\n" + + " \"dest\": {\"index\":\"dest-index\"},\n" + + " \"analysis\": {\"outlier_detection\": {\"n_neighbors\": 10}}\n" + + "}"; + + private static final String MODERN_QUERY_DATA_FRAME_ANALYTICS = "{\n" + + " \"id\": \"data-frame\",\n" + + // match_all if parsed, adds default values in the options + " \"source\": {\"index\":\"my-index\", \"query\": {\"match_all\" : {}}},\n" + + " \"dest\": {\"index\":\"dest-index\"},\n" + + " \"analysis\": {\"outlier_detection\": {\"n_neighbors\": 10}}\n" + + "}"; + + public void testQueryConfigStoresUserInputOnly() throws IOException { + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + MODERN_QUERY_DATA_FRAME_ANALYTICS)) { + + DataFrameAnalyticsConfig config = 
DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + assertThat(config.getSource().getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); + } + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + MODERN_QUERY_DATA_FRAME_ANALYTICS)) { + + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build(); + assertThat(config.getSource().getQuery(), equalTo(Collections.singletonMap(MatchAllQueryBuilder.NAME, Collections.emptyMap()))); + } + } + + public void testPastQueryConfigParse() throws IOException { + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) { + + DataFrameAnalyticsConfig config = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> config.getSource().getParsedQuery()); + assertEquals("[match] query doesn't support multiple fields, found [query] and [type]", e.getMessage()); + } + + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + ANACHRONISTIC_QUERY_DATA_FRAME_ANALYTICS)) { + + XContentParseException e = expectThrows(XContentParseException.class, + () -> DataFrameAnalyticsConfig.STRICT_PARSER.apply(parser, null).build()); + assertThat(e.getMessage(), containsString("[data_frame_analytics_config] failed to parse field [source]")); + } + } + + public void testToXContentForInternalStorage() throws IOException { + DataFrameAnalyticsConfig.Builder builder = createRandomBuilder("foo"); + + // headers are only persisted to cluster state + Map headers = new HashMap<>(); + headers.put("header-name", 
"header-value"); + builder.setHeaders(headers); + DataFrameAnalyticsConfig config = builder.build(); + + ToXContent.MapParams params = new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); + + BytesReference forClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, params, false); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, forClusterstateXContent.streamInput()); + + DataFrameAnalyticsConfig parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + assertThat(parsedConfig.getHeaders(), hasEntry("header-name", "header-value")); + + // headers are not written without the FOR_INTERNAL_STORAGE param + BytesReference nonClusterstateXContent = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); + parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry(), LoggingDeprecationHandler.INSTANCE, nonClusterstateXContent.streamInput()); + + parsedConfig = DataFrameAnalyticsConfig.LENIENT_PARSER.apply(parser, null).build(); + assertThat(parsedConfig.getHeaders().entrySet(), hasSize(0)); + } + + public void testInvalidModelMemoryLimits() { + + DataFrameAnalyticsConfig.Builder builder = new DataFrameAnalyticsConfig.Builder(); + + // All these are different ways of specifying a limit that is lower than the minimum + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(1048575, ByteSizeUnit.BYTES)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.BYTES)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(-1, ByteSizeUnit.BYTES)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> 
builder.setModelMemoryLimit(new ByteSizeValue(1023, ByteSizeUnit.KB)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.KB)))); + assertTooSmall(expectThrows(IllegalArgumentException.class, + () -> builder.setModelMemoryLimit(new ByteSizeValue(0, ByteSizeUnit.MB)))); + } + + public void testNoMemoryCapping() { + + DataFrameAnalyticsConfig uncapped = createRandom("foo"); + + ByteSizeValue unlimited = randomBoolean() ? null : ByteSizeValue.ZERO; + assertThat(uncapped.getModelMemoryLimit(), + equalTo(new DataFrameAnalyticsConfig.Builder(uncapped, unlimited).build().getModelMemoryLimit())); + } + + public void testMemoryCapping() { + + DataFrameAnalyticsConfig defaultLimitConfig = createRandomBuilder("foo").setModelMemoryLimit(null).build(); + + ByteSizeValue maxLimit = new ByteSizeValue(randomIntBetween(500, 1000), ByteSizeUnit.MB); + if (maxLimit.compareTo(defaultLimitConfig.getModelMemoryLimit()) < 0) { + assertThat(maxLimit, + equalTo(new DataFrameAnalyticsConfig.Builder(defaultLimitConfig, maxLimit).build().getModelMemoryLimit())); + } else { + assertThat(defaultLimitConfig.getModelMemoryLimit(), + equalTo(new DataFrameAnalyticsConfig.Builder(defaultLimitConfig, maxLimit).build().getModelMemoryLimit())); + } + } + + public void testExplicitModelMemoryLimitTooHigh() { + + ByteSizeValue configuredLimit = new ByteSizeValue(randomIntBetween(5, 10), ByteSizeUnit.GB); + DataFrameAnalyticsConfig explicitLimitConfig = createRandomBuilder("foo").setModelMemoryLimit(configuredLimit).build(); + + ByteSizeValue maxLimit = new ByteSizeValue(randomIntBetween(500, 1000), ByteSizeUnit.MB); + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new DataFrameAnalyticsConfig.Builder(explicitLimitConfig, maxLimit).build()); + assertThat(e.getMessage(), startsWith("model_memory_limit")); + assertThat(e.getMessage(), containsString("must be less than the value of 
the xpack.ml.max_model_memory_limit setting")); + } + + public void assertTooSmall(IllegalArgumentException e) { + assertThat(e.getMessage(), is("[model_memory_limit] must be at least [1mb]")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java new file mode 100644 index 0000000000000..bf8ce4c8a99b0 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsDestTests.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.dataframe; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.indices.InvalidIndexNameException; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.AbstractSerializingTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class DataFrameAnalyticsDestTests extends AbstractSerializingTestCase { + + @Override + protected DataFrameAnalyticsDest doParseInstance(XContentParser parser) throws IOException { + return DataFrameAnalyticsDest.createParser(false).apply(parser, null); + } + + @Override + protected DataFrameAnalyticsDest createTestInstance() { + return createRandom(); + } + + public static DataFrameAnalyticsDest createRandom() { + String index = randomAlphaOfLength(10); + String resultsField = randomBoolean() ? 
null : randomAlphaOfLength(10); + return new DataFrameAnalyticsDest(index, resultsField); + } + + @Override + protected Writeable.Reader instanceReader() { + return DataFrameAnalyticsDest::new; + } + + public void testValidate_GivenIndexWithFunkyChars() { + expectThrows(InvalidIndexNameException.class, () -> new DataFrameAnalyticsDest("