diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/StoredAsyncResponse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/StoredAsyncResponse.java
similarity index 97%
rename from x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/StoredAsyncResponse.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/StoredAsyncResponse.java
index 1d4d9a1f3eb3e..3d7cca850e4da 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/StoredAsyncResponse.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/StoredAsyncResponse.java
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.eql.async;
+package org.elasticsearch.xpack.core.async;
 
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -13,7 +13,6 @@
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.async.AsyncResponse;
 
 import java.io.IOException;
 import java.util.Objects;
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/StoredAsyncTask.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/StoredAsyncTask.java
similarity index 90%
rename from x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/StoredAsyncTask.java
rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/StoredAsyncTask.java
index b6e1a54910555..08e9b613b352c 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/StoredAsyncTask.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/StoredAsyncTask.java
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.eql.async;
+package org.elasticsearch.xpack.core.async;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionResponse;
@@ -13,8 +13,6 @@
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.tasks.TaskManager;
-import org.elasticsearch.xpack.core.async.AsyncExecutionId;
-import org.elasticsearch.xpack.core.async.AsyncTask;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -71,7 +69,7 @@ public synchronized void removeCompletionListener(ActionListener liste
     /**
      * This method is called when the task is finished successfully before unregistering the task and storing the results
      */
-    protected synchronized void onResponse(Response response) {
+    public synchronized void onResponse(Response response) {
         for (ActionListener<Response> listener : completionListeners) {
             listener.onResponse(response);
         }
@@ -80,7 +78,7 @@ protected synchronized void onResponse(Response response) {
     /**
      * This method is called when the task failed before unregistering the task and storing the results
      */
-    protected synchronized void onFailure(Exception e) {
+    public synchronized void onFailure(Exception e) {
         for (ActionListener<Response> listener : completionListeners) {
             listener.onFailure(e);
         }
@@ -89,7 +87,7 @@ protected synchronized void onFailure(Exception e) {
     /**
      * Return currently available partial or the final results
     */
-    protected abstract Response getCurrentResult();
+    public abstract Response getCurrentResult();
 
     @Override
     public void cancelTask(TaskManager taskManager, Runnable runnable, String reason) {
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/sql/SqlAsyncActionNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/sql/SqlAsyncActionNames.java
new file mode 100644
index 0000000000000..70aa8ccaa2e58
--- /dev/null
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/sql/SqlAsyncActionNames.java
@@ -0,0 +1,16 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.sql;
+
+/**
+ * Exposes SQL async action names for the RBAC engine
+ */
+public class SqlAsyncActionNames {
+    public static final String SQL_ASYNC_GET_RESULT_ACTION_NAME = "indices:data/read/sql/async/get";
+    public static final String SQL_ASYNC_GET_STATUS_ACTION_NAME = "cluster:monitor/xpack/sql/async/status";
+}
diff --git a/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/AsyncEqlSearchActionIT.java b/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/AsyncEqlSearchActionIT.java
index 1f5bd986fa08a..8f5e3ebeb03b1 100644
--- a/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/AsyncEqlSearchActionIT.java
+++ b/x-pack/plugin/eql/src/internalClusterTest/java/org/elasticsearch/xpack/eql/action/AsyncEqlSearchActionIT.java
@@ -33,7 +33,7 @@
 import org.elasticsearch.xpack.core.async.DeleteAsyncResultAction;
 import org.elasticsearch.xpack.core.async.DeleteAsyncResultRequest;
 import org.elasticsearch.xpack.core.async.GetAsyncResultRequest;
-import org.elasticsearch.xpack.eql.async.StoredAsyncResponse;
+import org.elasticsearch.xpack.core.async.StoredAsyncResponse;
 import org.elasticsearch.xpack.eql.plugin.EqlAsyncGetResultAction;
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchResponse.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchResponse.java
index 6f07dd972cef4..1666f968540ad 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchResponse.java
+++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchResponse.java
@@ -30,6 +30,7 @@
 import org.elasticsearch.index.get.GetResult;
 import org.elasticsearch.index.mapper.SourceFieldMapper;
 import org.elasticsearch.search.SearchHits;
+import org.elasticsearch.xpack.ql.async.QlStatusResponse;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -41,7 +42,7 @@
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
-public class EqlSearchResponse extends ActionResponse implements ToXContentObject {
+public class EqlSearchResponse extends ActionResponse implements ToXContentObject, QlStatusResponse.AsyncStatus {
 
     private final Hits hits;
     private final long tookInMillis;
@@ -150,14 +151,17 @@ public Hits hits() {
         return hits;
     }
 
+    @Override
     public String id() {
         return asyncExecutionId;
     }
 
+    @Override
     public boolean isRunning() {
         return isRunning;
     }
 
+    @Override
     public boolean isPartial() {
         return isPartial;
     }
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchTask.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchTask.java
index 0f28a09f37295..41c715c950c32 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchTask.java
+++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlSearchTask.java
@@ -10,7 +10,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.async.AsyncExecutionId;
-import org.elasticsearch.xpack.eql.async.StoredAsyncTask;
+import org.elasticsearch.xpack.core.async.StoredAsyncTask;
 
 import java.util.Map;
 
@@ -27,19 +27,4 @@ public EqlSearchResponse getCurrentResult() {
         return new EqlSearchResponse(EqlSearchResponse.Hits.EMPTY, System.currentTimeMillis() - getStartTime(), false,
             getExecutionId().getEncoded(), true, true);
     }
-
-
-    /**
-     * Returns the status from {@link EqlSearchTask}
-     */
-    public static EqlStatusResponse getStatusResponse(EqlSearchTask asyncTask) {
-        return new EqlStatusResponse(
-            asyncTask.getExecutionId().getEncoded(),
-            true,
-            true,
-            asyncTask.getStartTime(),
-            asyncTask.getExpirationTimeMillis(),
-            null
-        );
-    }
 }
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlAsyncGetStatusAction.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlAsyncGetStatusAction.java
index a43f5fc5c067b..24d6fd6fee68a 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlAsyncGetStatusAction.java
+++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlAsyncGetStatusAction.java
@@ -7,13 +7,13 @@
 package org.elasticsearch.xpack.eql.plugin;
 
 import org.elasticsearch.action.ActionType;
-import org.elasticsearch.xpack.eql.action.EqlStatusResponse;
+import org.elasticsearch.xpack.ql.async.QlStatusResponse;
 
-public class EqlAsyncGetStatusAction extends ActionType<EqlStatusResponse> {
+public class EqlAsyncGetStatusAction extends ActionType<QlStatusResponse> {
     public static final EqlAsyncGetStatusAction INSTANCE = new EqlAsyncGetStatusAction();
     public static final String NAME = "cluster:monitor/eql/async/status";
 
     private EqlAsyncGetStatusAction() {
-        super(NAME, EqlStatusResponse::new);
+        super(NAME, QlStatusResponse::new);
     }
 }
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlPlugin.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlPlugin.java
index 7e01382a6a663..f42c6f5988133 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlPlugin.java
+++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/EqlPlugin.java
@@ -87,7 +87,7 @@ public List<Setting<?>> getSettings() {
         return List.of(
             new ActionHandler<>(EqlSearchAction.INSTANCE, TransportEqlSearchAction.class),
             new ActionHandler<>(EqlStatsAction.INSTANCE, TransportEqlStatsAction.class),
-            new ActionHandler<>(EqlAsyncGetResultAction.INSTANCE, TransportEqlAsyncGetResultAction.class),
+            new ActionHandler<>(EqlAsyncGetResultAction.INSTANCE, TransportEqlAsyncGetResultsAction.class),
             new ActionHandler<>(EqlAsyncGetStatusAction.INSTANCE, TransportEqlAsyncGetStatusAction.class),
             new ActionHandler<>(XPackUsageFeatureAction.EQL, EqlUsageTransportAction.class),
             new ActionHandler<>(XPackInfoFeatureAction.EQL, EqlInfoTransportAction.class)
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetResultsAction.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetResultsAction.java
new file mode 100644
index 0000000000000..9027586d96bc4
--- /dev/null
+++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetResultsAction.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.eql.plugin;
+
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.eql.EqlAsyncActionNames;
+import org.elasticsearch.xpack.eql.action.EqlSearchResponse;
+import org.elasticsearch.xpack.eql.action.EqlSearchTask;
+import org.elasticsearch.xpack.ql.plugin.AbstractTransportQlAsyncGetResultsAction;
+
+public class TransportEqlAsyncGetResultsAction extends AbstractTransportQlAsyncGetResultsAction<EqlSearchResponse, EqlSearchTask> {
+
+    @Inject
+    public TransportEqlAsyncGetResultsAction(TransportService transportService,
+                                             ActionFilters actionFilters,
+                                             ClusterService clusterService,
+                                             NamedWriteableRegistry registry,
+                                             Client client,
+                                             ThreadPool threadPool,
+                                             BigArrays bigArrays) {
+        super(EqlAsyncActionNames.EQL_ASYNC_GET_RESULT_ACTION_NAME, transportService, actionFilters, clusterService, registry, client,
+            threadPool, bigArrays, EqlSearchTask.class);
+    }
+
+    @Override
+    public Writeable.Reader<EqlSearchResponse> responseReader() {
+        return EqlSearchResponse::new;
+    }
+}
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetStatusAction.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetStatusAction.java
index b2514a947112c..57b3a94a5f17e 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetStatusAction.java
+++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetStatusAction.java
@@ -6,39 +6,21 @@
  */
 package org.elasticsearch.xpack.eql.plugin;
 
-import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.action.ActionListenerResponseHandler;
 import org.elasticsearch.action.support.ActionFilters;
-import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.Client;
-import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.BigArrays;
-import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
-import org.elasticsearch.xpack.core.XPackPlugin;
-import org.elasticsearch.xpack.core.async.AsyncExecutionId;
-import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
-import org.elasticsearch.xpack.core.async.GetAsyncStatusRequest;
 import org.elasticsearch.xpack.eql.action.EqlSearchResponse;
 import org.elasticsearch.xpack.eql.action.EqlSearchTask;
-import org.elasticsearch.xpack.eql.action.EqlStatusResponse;
-import org.elasticsearch.xpack.eql.async.StoredAsyncResponse;
+import org.elasticsearch.xpack.ql.plugin.AbstractTransportQlAsyncGetStatusAction;
 
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN;
-
-
-public class TransportEqlAsyncGetStatusAction extends HandledTransportAction<GetAsyncStatusRequest, EqlStatusResponse> {
-    private final TransportService transportService;
-    private final ClusterService clusterService;
-    private final AsyncTaskIndexService<StoredAsyncResponse<EqlSearchResponse>> store;
+public class TransportEqlAsyncGetStatusAction extends AbstractTransportQlAsyncGetStatusAction<EqlSearchResponse, EqlSearchTask> {
 
     @Inject
     public TransportEqlAsyncGetStatusAction(TransportService transportService,
                                             ActionFilters actionFilters,
@@ -47,31 +29,12 @@ public TransportEqlAsyncGetStatusAction(TransportService transportService,
                                             Client client,
                                             ThreadPool threadPool,
                                             BigArrays bigArrays) {
-        super(EqlAsyncGetStatusAction.NAME, transportService, actionFilters, GetAsyncStatusRequest::new);
-        this.transportService = transportService;
-        this.clusterService = clusterService;
-        Writeable.Reader<StoredAsyncResponse<EqlSearchResponse>> reader = in -> new StoredAsyncResponse<>(EqlSearchResponse::new, in);
-        this.store = new AsyncTaskIndexService<>(XPackPlugin.ASYNC_RESULTS_INDEX, clusterService,
-            threadPool.getThreadContext(), client, ASYNC_SEARCH_ORIGIN, reader, registry, bigArrays);
+        super(EqlAsyncGetStatusAction.NAME, transportService, actionFilters, clusterService, registry, client, threadPool, bigArrays,
+            EqlSearchTask.class);
     }
 
     @Override
-    protected void doExecute(Task task, GetAsyncStatusRequest request, ActionListener<EqlStatusResponse> listener) {
-        AsyncExecutionId searchId = AsyncExecutionId.decode(request.getId());
-        DiscoveryNode node = clusterService.state().nodes().get(searchId.getTaskId().getNodeId());
-        DiscoveryNode localNode = clusterService.state().getNodes().getLocalNode();
-        if (node == null || Objects.equals(node, localNode)) {
-            store.retrieveStatus(
-                request,
-                taskManager,
-                EqlSearchTask.class,
-                EqlSearchTask::getStatusResponse,
-                EqlStatusResponse::getStatusFromStoredSearch,
-                listener
-            );
-        } else {
-            transportService.sendRequest(node, EqlAsyncGetStatusAction.NAME, request,
-                new ActionListenerResponseHandler<>(listener, EqlStatusResponse::new, ThreadPool.Names.SAME));
-        }
+    protected Writeable.Reader<EqlSearchResponse> responseReader() {
+        return EqlSearchResponse::new;
     }
 }
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlSearchAction.java b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlSearchAction.java
index 06d19d6552c3f..a64d0a0f1a43e 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlSearchAction.java
+++ b/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlSearchAction.java
@@ -35,11 +35,11 @@
 import org.elasticsearch.xpack.eql.action.EqlSearchRequest;
 import org.elasticsearch.xpack.eql.action.EqlSearchResponse;
 import org.elasticsearch.xpack.eql.action.EqlSearchTask;
-import org.elasticsearch.xpack.eql.async.AsyncTaskManagementService;
 import org.elasticsearch.xpack.eql.execution.PlanExecutor;
 import org.elasticsearch.xpack.eql.parser.ParserParams;
 import org.elasticsearch.xpack.eql.session.EqlConfiguration;
 import org.elasticsearch.xpack.eql.session.Results;
+import org.elasticsearch.xpack.ql.async.AsyncTaskManagementService;
 import org.elasticsearch.xpack.ql.expression.Order;
 
 import java.io.IOException;
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/AsyncTaskManagementService.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java
similarity index 98%
rename from x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/AsyncTaskManagementService.java
rename to x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java
index 6d5914ab9ed01..60b64458434ac 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/async/AsyncTaskManagementService.java
+++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementService.java
@@ -5,7 +5,7 @@
  * 2.0.
  */
 
-package org.elasticsearch.xpack.eql.async;
+package org.elasticsearch.xpack.ql.async;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
@@ -34,6 +34,8 @@
 import org.elasticsearch.xpack.core.async.AsyncExecutionId;
 import org.elasticsearch.xpack.core.async.AsyncTask;
 import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
+import org.elasticsearch.xpack.core.async.StoredAsyncResponse;
+import org.elasticsearch.xpack.core.async.StoredAsyncTask;
 
 import java.io.IOException;
 import java.util.Map;
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlStatusResponse.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/QlStatusResponse.java
similarity index 68%
rename from x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlStatusResponse.java
rename to x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/QlStatusResponse.java
index 6ea2091485d97..1ac1943d2f940 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/action/EqlStatusResponse.java
+++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/async/QlStatusResponse.java
@@ -4,17 +4,18 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-package org.elasticsearch.xpack.eql.action;
+package org.elasticsearch.xpack.ql.async;
 
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.StatusToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.async.StoredAsyncResponse;
 import org.elasticsearch.xpack.core.search.action.SearchStatusResponse;
-import org.elasticsearch.xpack.eql.async.StoredAsyncResponse;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -22,9 +23,9 @@
 import static org.elasticsearch.rest.RestStatus.OK;
 
 /**
- * A response for eql search status request
+ * A response for *QL search status request
 */
-public class EqlStatusResponse extends ActionResponse implements SearchStatusResponse, StatusToXContentObject {
+public class QlStatusResponse extends ActionResponse implements SearchStatusResponse, StatusToXContentObject {
     private final String id;
     private final boolean isRunning;
     private final boolean isPartial;
@@ -32,12 +33,20 @@ public class EqlStatusResponse extends ActionResponse implements SearchStatusRes
     private final long expirationTimeMillis;
     private final RestStatus completionStatus;
 
-    public EqlStatusResponse(String id,
-                             boolean isRunning,
-                             boolean isPartial,
-                             Long startTimeMillis,
-                             long expirationTimeMillis,
-                             RestStatus completionStatus) {
+    public interface AsyncStatus {
+        String id();
+
+        boolean isRunning();
+
+        boolean isPartial();
+    }
+
+    public QlStatusResponse(String id,
+                            boolean isRunning,
+                            boolean isPartial,
+                            Long startTimeMillis,
+                            long expirationTimeMillis,
+                            RestStatus completionStatus) {
         this.id = id;
         this.isRunning = isRunning;
         this.isPartial = isPartial;
@@ -47,40 +56,40 @@ public EqlStatusResponse(String id,
     }
 
     /**
-     * Get status from the stored eql search response
+     * Get status from the stored Ql search response
      * @param storedResponse - a response from a stored search
      * @param expirationTimeMillis – expiration time in milliseconds
      * @param id – encoded async search id
      * @return a status response
      */
-    public static EqlStatusResponse getStatusFromStoredSearch(StoredAsyncResponse<EqlSearchResponse> storedResponse,
-                                                               long expirationTimeMillis, String id) {
-        EqlSearchResponse searchResponse = storedResponse.getResponse();
+    public static <S extends AsyncStatus> QlStatusResponse getStatusFromStoredSearch(StoredAsyncResponse<S> storedResponse,
+                                                                                      long expirationTimeMillis, String id) {
+        S searchResponse = storedResponse.getResponse();
         if (searchResponse != null) {
-            assert searchResponse.isRunning() == false : "Stored eql search response must have a completed status!";
-            return new EqlStatusResponse(
+            assert searchResponse.isRunning() == false : "Stored Ql search response must have a completed status!";
+            return new QlStatusResponse(
                 searchResponse.id(),
                 false,
                 searchResponse.isPartial(),
-                null, // we dont' store in the index start time for completed response
+                null, // we don't store in the index the start time for completed response
                 expirationTimeMillis,
                 RestStatus.OK
             );
         } else {
             Exception exc = storedResponse.getException();
-            assert exc != null : "Stored eql response must either have a search response or an exception!";
-            return new EqlStatusResponse(
+            assert exc != null : "Stored Ql response must either have a search response or an exception!";
+            return new QlStatusResponse(
                 id,
                 false,
                 false,
-                null, // we dont' store in the index start time for completed response
+                null, // we don't store in the index the start time for completed response
                 expirationTimeMillis,
                 ExceptionsHelper.status(exc)
             );
         }
     }
 
-    public EqlStatusResponse(StreamInput in) throws IOException {
+    public QlStatusResponse(StreamInput in) throws IOException {
         this.id = in.readString();
         this.isRunning = in.readBoolean();
         this.isPartial = in.readBoolean();
@@ -109,15 +118,17 @@ public RestStatus status() {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        builder.field("id", id);
-        builder.field("is_running", isRunning);
-        builder.field("is_partial", isPartial);
-        if (startTimeMillis != null) { // start time is available only for a running eql search
-            builder.timeField("start_time_in_millis", "start_time", startTimeMillis);
-        }
-        builder.timeField("expiration_time_in_millis", "expiration_time", expirationTimeMillis);
-        if (isRunning == false) { // completion status is available only for a completed eql search
-            builder.field("completion_status", completionStatus.getStatus());
+        {
+            builder.field("id", id);
+            builder.field("is_running", isRunning);
+            builder.field("is_partial", isPartial);
+            if (startTimeMillis != null) { // start time is available only for a running eql search
+                builder.timeField("start_time_in_millis", "start_time", startTimeMillis);
+            }
+            builder.timeField("expiration_time_in_millis", "expiration_time", expirationTimeMillis);
+            if (isRunning == false) { // completion status is available only for a completed eql search
+                builder.field("completion_status", completionStatus.getStatus());
+            }
         }
         builder.endObject();
         return builder;
@@ -127,7 +138,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
     public boolean equals(Object obj) {
         if (this == obj) return true;
         if (obj == null || getClass() != obj.getClass()) return false;
-        EqlStatusResponse other = (EqlStatusResponse) obj;
+        QlStatusResponse other = (QlStatusResponse) obj;
         return id.equals(other.id)
             && isRunning == other.isRunning
             && isPartial == other.isPartial
diff --git a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetResultAction.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/plugin/AbstractTransportQlAsyncGetResultsAction.java
similarity index 52%
rename from x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetResultAction.java
rename to x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/plugin/AbstractTransportQlAsyncGetResultsAction.java
index 8f9a231e2f1c6..a9c14257382d8 100644
--- a/x-pack/plugin/eql/src/main/java/org/elasticsearch/xpack/eql/plugin/TransportEqlAsyncGetResultAction.java
+++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/plugin/AbstractTransportQlAsyncGetResultsAction.java
@@ -4,16 +4,16 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-package org.elasticsearch.xpack.eql.plugin;
+package org.elasticsearch.xpack.ql.plugin;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.ActionResponse;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.BigArrays;
@@ -24,48 +24,52 @@
 import org.elasticsearch.xpack.core.async.AsyncResultsService;
 import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
 import org.elasticsearch.xpack.core.async.GetAsyncResultRequest;
-import org.elasticsearch.xpack.core.eql.EqlAsyncActionNames;
-import org.elasticsearch.xpack.eql.action.EqlSearchResponse;
-import org.elasticsearch.xpack.eql.action.EqlSearchTask;
-import org.elasticsearch.xpack.eql.async.AsyncTaskManagementService;
-import org.elasticsearch.xpack.eql.async.StoredAsyncResponse;
+import org.elasticsearch.xpack.ql.async.AsyncTaskManagementService;
+import org.elasticsearch.xpack.core.async.StoredAsyncResponse;
+import org.elasticsearch.xpack.core.async.StoredAsyncTask;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN;
 
-public class TransportEqlAsyncGetResultAction extends HandledTransportAction<GetAsyncResultRequest, EqlSearchResponse> {
-    private final AsyncResultsService<EqlSearchTask, StoredAsyncResponse<EqlSearchResponse>> resultsService;
+public abstract class AbstractTransportQlAsyncGetResultsAction<Response extends ActionResponse, AsyncTask extends StoredAsyncTask<Response>>
+    extends HandledTransportAction<GetAsyncResultRequest, Response> {
+    private final String actionName;
+    private final AsyncResultsService<AsyncTask, StoredAsyncResponse<Response>> resultsService;
     private final TransportService transportService;
 
-    @Inject
-    public TransportEqlAsyncGetResultAction(TransportService transportService,
-                                            ActionFilters actionFilters,
-                                            ClusterService clusterService,
-                                            NamedWriteableRegistry registry,
-                                            Client client,
-                                            ThreadPool threadPool,
-                                            BigArrays bigArrays) {
-        super(EqlAsyncActionNames.EQL_ASYNC_GET_RESULT_ACTION_NAME, transportService, actionFilters, GetAsyncResultRequest::new);
+    public AbstractTransportQlAsyncGetResultsAction(String actionName,
+                                                    TransportService transportService,
+                                                    ActionFilters actionFilters,
+                                                    ClusterService clusterService,
+                                                    NamedWriteableRegistry registry,
+                                                    Client client,
+                                                    ThreadPool threadPool,
+                                                    BigArrays bigArrays,
+                                                    Class<? extends AsyncTask> asynkTaskClass) {
+        super(actionName, transportService, actionFilters, GetAsyncResultRequest::new);
+        this.actionName = actionName;
         this.transportService = transportService;
-        this.resultsService = createResultsService(transportService, clusterService, registry, client, threadPool, bigArrays);
+        this.resultsService = createResultsService(transportService, clusterService, registry, client, threadPool, bigArrays,
+            asynkTaskClass);
     }
 
-    static AsyncResultsService<EqlSearchTask, StoredAsyncResponse<EqlSearchResponse>> createResultsService(
+    AsyncResultsService<AsyncTask, StoredAsyncResponse<Response>> createResultsService(
         TransportService transportService,
         ClusterService clusterService,
         NamedWriteableRegistry registry,
         Client client,
         ThreadPool threadPool,
-        BigArrays bigArrays) {
-        Writeable.Reader<StoredAsyncResponse<EqlSearchResponse>> reader = in -> new StoredAsyncResponse<>(EqlSearchResponse::new, in);
-        AsyncTaskIndexService<StoredAsyncResponse<EqlSearchResponse>> store = new AsyncTaskIndexService<>(XPackPlugin.ASYNC_RESULTS_INDEX,
+        BigArrays bigArrays,
+        Class<? extends AsyncTask> asyncTaskClass) {
+        Writeable.Reader<StoredAsyncResponse<Response>> reader = in -> new StoredAsyncResponse<>(responseReader(), in);
+        AsyncTaskIndexService<StoredAsyncResponse<Response>> store = new AsyncTaskIndexService<>(XPackPlugin.ASYNC_RESULTS_INDEX,
            clusterService, threadPool.getThreadContext(), client, ASYNC_SEARCH_ORIGIN, reader, registry, bigArrays);
-        return new AsyncResultsService<>(store, false, EqlSearchTask.class,
+        return new AsyncResultsService<>(store, false, asyncTaskClass,
            (task, listener, timeout) -> AsyncTaskManagementService.addCompletionListener(threadPool, task, listener, timeout),
            transportService.getTaskManager(), clusterService);
     }
 
     @Override
-    protected void doExecute(Task task, GetAsyncResultRequest request, ActionListener<EqlSearchResponse> listener) {
+    protected void doExecute(Task task, GetAsyncResultRequest request, ActionListener<Response> listener) {
         DiscoveryNode node = resultsService.getNode(request.getId());
         if (node == null || resultsService.isLocalNode(node)) {
             resultsService.retrieveResult(request, ActionListener.wrap(
@@ -79,8 +83,10 @@ protected void doExecute(Task task, GetAsyncResultRequest request, ActionListene
                 listener::onFailure
             ));
         } else {
-            transportService.sendRequest(node, EqlAsyncActionNames.EQL_ASYNC_GET_RESULT_ACTION_NAME, request,
-                new ActionListenerResponseHandler<>(listener, EqlSearchResponse::new, ThreadPool.Names.SAME));
+            transportService.sendRequest(node, actionName, request,
+                new ActionListenerResponseHandler<>(listener, responseReader(), ThreadPool.Names.SAME));
         }
     }
+
+    public abstract Writeable.Reader<Response> responseReader();
 }
diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/plugin/AbstractTransportQlAsyncGetStatusAction.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/plugin/AbstractTransportQlAsyncGetStatusAction.java
new file mode 100644
index 0000000000000..daea69a9823dc
--- /dev/null
+++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/plugin/AbstractTransportQlAsyncGetStatusAction.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.ql.plugin;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionListenerResponseHandler;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.util.BigArrays;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.XPackPlugin;
+import org.elasticsearch.xpack.core.async.AsyncExecutionId;
+import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
+import org.elasticsearch.xpack.core.async.GetAsyncStatusRequest;
+import org.elasticsearch.xpack.core.async.StoredAsyncResponse;
+import org.elasticsearch.xpack.core.async.StoredAsyncTask;
+import org.elasticsearch.xpack.ql.async.QlStatusResponse;
+
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN;
+
+
+public abstract class AbstractTransportQlAsyncGetStatusAction<Response extends ActionResponse,
+        AsyncTask extends StoredAsyncTask<Response>> extends HandledTransportAction<GetAsyncStatusRequest, QlStatusResponse> {
+    private final String actionName;
+    private final TransportService transportService;
+    private final ClusterService clusterService;
+    private final Class<? extends AsyncTask> asyncTaskClass;
+    private final AsyncTaskIndexService<StoredAsyncResponse<Response>> store;
+
+    public AbstractTransportQlAsyncGetStatusAction(String actionName,
+                                                   TransportService transportService,
+                                                   ActionFilters actionFilters,
+                                                   ClusterService clusterService,
+                                                   NamedWriteableRegistry registry,
+                                                   Client client,
+                                                   ThreadPool threadPool,
+                                                   BigArrays bigArrays,
+                                                   Class<? extends AsyncTask> asyncTaskClass) {
+        super(actionName, transportService, actionFilters, GetAsyncStatusRequest::new);
+        this.actionName = actionName;
+        this.transportService = transportService;
+        this.clusterService = clusterService;
+        this.asyncTaskClass = asyncTaskClass;
+        Writeable.Reader<StoredAsyncResponse<Response>> reader = in -> new StoredAsyncResponse<>(responseReader(), in);
+        this.store = new AsyncTaskIndexService<>(XPackPlugin.ASYNC_RESULTS_INDEX, clusterService,
+            threadPool.getThreadContext(), client, ASYNC_SEARCH_ORIGIN, reader, registry, bigArrays);
+    }
+
+    @Override
+    protected void doExecute(Task task, GetAsyncStatusRequest request, ActionListener<QlStatusResponse> listener) {
+        AsyncExecutionId searchId = AsyncExecutionId.decode(request.getId());
+        DiscoveryNode node = clusterService.state().nodes().get(searchId.getTaskId().getNodeId());
+        DiscoveryNode localNode = clusterService.state().getNodes().getLocalNode();
+        if (node == null || Objects.equals(node, localNode)) {
+            store.retrieveStatus(
+                request,
+                taskManager,
+                asyncTaskClass,
+                AbstractTransportQlAsyncGetStatusAction::getStatusResponse,
+                QlStatusResponse::getStatusFromStoredSearch,
+                listener
+            );
+        } else {
+            transportService.sendRequest(node, actionName, request,
+                new ActionListenerResponseHandler<>(listener, QlStatusResponse::new, ThreadPool.Names.SAME));
+        }
+    }
+
+    private static QlStatusResponse getStatusResponse(StoredAsyncTask<?> asyncTask) {
+        return new QlStatusResponse(
+            asyncTask.getExecutionId().getEncoded(),
+            true,
+            true,
+            asyncTask.getStartTime(),
+            asyncTask.getExpirationTimeMillis(),
+            null
+        );
+    }
+
+    protected abstract Writeable.Reader<Response> responseReader();
+}
diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/action/EqlStatusResponseTests.java b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/action/QlStatusResponseTests.java
similarity index 82%
rename from x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/action/EqlStatusResponseTests.java
rename to x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/action/QlStatusResponseTests.java
index 2b8aecf8fade3..a8cc71e78500c 100644
--- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/action/EqlStatusResponseTests.java
+++ b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/action/QlStatusResponseTests.java
@@ -4,7 +4,7 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-package org.elasticsearch.xpack.eql.action;
+package org.elasticsearch.xpack.ql.action;
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.Writeable;
@@ -13,16 +13,17 @@
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.ql.async.QlStatusResponse;
 
 import java.io.IOException;
 import java.util.Date;
 
 import static org.elasticsearch.xpack.core.async.GetAsyncResultRequestTests.randomSearchId;
 
-public class EqlStatusResponseTests extends AbstractWireSerializingTestCase<EqlStatusResponse> {
+public class QlStatusResponseTests extends AbstractWireSerializingTestCase<QlStatusResponse> {
 
     @Override
-    protected EqlStatusResponse createTestInstance() {
+    protected QlStatusResponse createTestInstance() {
         String id = randomSearchId();
         boolean isRunning = randomBoolean();
         boolean isPartial = isRunning ? randomBoolean() : false;
@@ -30,21 +31,21 @@ protected EqlStatusResponse createTestInstance() {
         Long startTimeMillis = randomBoolean() ? null : randomDate;
         long expirationTimeMillis = startTimeMillis == null ? randomDate : startTimeMillis + 3600000L;
         RestStatus completionStatus = isRunning ? null : randomBoolean() ? RestStatus.OK : RestStatus.SERVICE_UNAVAILABLE;
-        return new EqlStatusResponse(id, isRunning, isPartial, startTimeMillis, expirationTimeMillis, completionStatus);
+        return new QlStatusResponse(id, isRunning, isPartial, startTimeMillis, expirationTimeMillis, completionStatus);
     }
 
     @Override
-    protected Writeable.Reader<EqlStatusResponse> instanceReader() {
-        return EqlStatusResponse::new;
+    protected Writeable.Reader<QlStatusResponse> instanceReader() {
+        return QlStatusResponse::new;
     }
 
     @Override
-    protected EqlStatusResponse mutateInstance(EqlStatusResponse instance) {
+    protected QlStatusResponse mutateInstance(QlStatusResponse instance) {
         // return a response with the opposite running status
         boolean isRunning = instance.isRunning() == false;
         boolean isPartial = isRunning ? randomBoolean() : false;
         RestStatus completionStatus = isRunning ? null : randomBoolean() ? RestStatus.OK : RestStatus.SERVICE_UNAVAILABLE;
-        return new EqlStatusResponse(
+        return new QlStatusResponse(
             instance.getId(),
             isRunning,
             isPartial,
@@ -55,7 +56,7 @@ protected EqlStatusResponse mutateInstance(EqlStatusResponse instance) {
     }
 
     public void testToXContent() throws IOException {
-        EqlStatusResponse response = createTestInstance();
+        QlStatusResponse response = createTestInstance();
         try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) {
             String expectedJson = "{\n" +
                 "  \"id\" : \"" + response.getId() + "\",\n" +
diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/async/AsyncTaskManagementServiceTests.java b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementServiceTests.java
similarity index 98%
rename from x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/async/AsyncTaskManagementServiceTests.java
rename to x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementServiceTests.java
index fb0c4ea79b733..068f7900f2860 100644
--- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/async/AsyncTaskManagementServiceTests.java
+++ b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/async/AsyncTaskManagementServiceTests.java
@@ -4,7 +4,7 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-package org.elasticsearch.xpack.eql.async;
+package org.elasticsearch.xpack.ql.async;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequest;
@@ -23,6 +23,8 @@
 import org.elasticsearch.xpack.core.async.AsyncResultsService;
 import org.elasticsearch.xpack.core.async.AsyncTaskIndexService;
 import org.elasticsearch.xpack.core.async.GetAsyncResultRequest;
+import org.elasticsearch.xpack.core.async.StoredAsyncResponse;
+import org.elasticsearch.xpack.core.async.StoredAsyncTask;
 import org.junit.After;
 import org.junit.Before;
 
@@ -35,7 +37,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 
-import static org.elasticsearch.xpack.eql.async.AsyncTaskManagementService.addCompletionListener;
+import static org.elasticsearch.xpack.ql.async.AsyncTaskManagementService.addCompletionListener;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
diff --git a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/async/StoredAsyncResponseTests.java b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/async/StoredAsyncResponseTests.java
similarity index 96%
rename from x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/async/StoredAsyncResponseTests.java
rename to x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/async/StoredAsyncResponseTests.java
index 4832cdd56bb84..81e72556ce6b3 100644
--- a/x-pack/plugin/eql/src/test/java/org/elasticsearch/xpack/eql/async/StoredAsyncResponseTests.java
+++ b/x-pack/plugin/ql/src/test/java/org/elasticsearch/xpack/ql/async/StoredAsyncResponseTests.java
@@ -4,7 +4,7 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-package org.elasticsearch.xpack.eql.async;
+package org.elasticsearch.xpack.ql.async;
 
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -14,6 +14,7 @@
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.search.SearchModule;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.async.StoredAsyncResponse;
 
 import java.io.IOException;
 import java.util.Collections;
diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
index 89cd0b41dc274..af343b22adeed 100644
--- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
+++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
@@ -303,6 +303,7 @@ public class Constants {
         "cluster:monitor/xpack/searchable_snapshots/stats",
         "cluster:monitor/xpack/security/saml/metadata",
         "cluster:monitor/xpack/spatial/stats",
+        "cluster:monitor/xpack/sql/async/status", // org.elasticsearch.xpack.core.sql.SqlAsyncActionNames.SQL_ASYNC_GET_STATUS_ACTION_NAME
         "cluster:monitor/xpack/sql/stats/dist",
         "cluster:monitor/xpack/ssl/certificates/get",
         "cluster:monitor/xpack/usage",
@@ -414,6 +415,7 @@ public class Constants {
         "indices:data/read/sql",
         "indices:data/read/sql/close_cursor",
         "indices:data/read/sql/translate",
+        "indices:data/read/sql/async/get", // org.elasticsearch.xpack.core.sql.SqlAsyncActionNames.SQL_ASYNC_GET_RESULT_ACTION_NAME
         "indices:data/read/tv",
         "indices:data/read/xpack/ccr/shard_changes",
         "indices:data/read/xpack/enrich/coordinate_lookups",
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/RBACEngine.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/RBACEngine.java
index 95d46bb251c36..8f658e290ed81 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/RBACEngine.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/RBACEngine.java
@@ -74,6 +74,7 @@
 import org.elasticsearch.xpack.core.security.authz.privilege.Privilege;
 import org.elasticsearch.xpack.core.security.support.StringMatcher;
 import org.elasticsearch.xpack.core.security.user.User;
+import org.elasticsearch.xpack.core.sql.SqlAsyncActionNames;
 import org.elasticsearch.xpack.security.authc.ApiKeyService;
 import org.elasticsearch.xpack.security.authc.esnative.ReservedRealm;
 import org.elasticsearch.xpack.security.authz.store.CompositeRolesStore;
@@ -629,6 +630,7 @@ private static boolean isAsyncRelatedAction(String action) {
         return action.equals(SubmitAsyncSearchAction.NAME) ||
             action.equals(GetAsyncSearchAction.NAME) ||
             action.equals(DeleteAsyncResultAction.NAME) ||
-            action.equals(EqlAsyncActionNames.EQL_ASYNC_GET_RESULT_ACTION_NAME);
+            action.equals(EqlAsyncActionNames.EQL_ASYNC_GET_RESULT_ACTION_NAME) ||
+            action.equals(SqlAsyncActionNames.SQL_ASYNC_GET_RESULT_ACTION_NAME);
     }
 }
diff --git a/x-pack/plugin/sql/qa/server/security/build.gradle b/x-pack/plugin/sql/qa/server/security/build.gradle
index 7bbf9d72057e1..375c654645955 100644
--- a/x-pack/plugin/sql/qa/server/security/build.gradle
+++ b/x-pack/plugin/sql/qa/server/security/build.gradle
@@ -38,6 +38,9 @@ subprojects {
     /* Setup the one admin user that we run the tests as.
      * Tests use "run as" to get different users. */
     user username: "test_admin", password: "x-pack-test-password"
+    user username: "user1", password: 'x-pack-test-password', role: "user1"
+    user username: "user2", password: 'x-pack-test-password', role: "user2"
+    user username: "manage_user", password: 'x-pack-test-password', role: "manage_user"
   }
 
   File testArtifactsDir = project.file("$buildDir/testArtifacts")
diff --git a/x-pack/plugin/sql/qa/server/security/roles.yml b/x-pack/plugin/sql/qa/server/security/roles.yml
index 141314e23f024..01c9698681968 100644
--- a/x-pack/plugin/sql/qa/server/security/roles.yml
+++ b/x-pack/plugin/sql/qa/server/security/roles.yml
@@ -89,3 +89,29 @@ no_get_index:
     privileges: [monitor]
   - names: bort
     privileges: [monitor]
+
+user1:
+  cluster:
+    - cluster:monitor/main
+  indices:
+    - names: ['index-user1', 'index' ]
+      privileges:
+        - read
+        - write
+        - create_index
+        - indices:admin/refresh
+
+user2:
+  cluster:
+    - cluster:monitor/main
+  indices:
+    - names: [ 'index-user2', 'index' ]
+      privileges:
+        - read
+        - write
+        - create_index
+        - indices:admin/refresh
+
+manage_user:
+  cluster:
+    - manage
diff --git a/x-pack/plugin/sql/qa/server/security/src/test/java/org/elasticsearch/xpack/sql/qa/security/RestSqlSecurityAsyncIT.java b/x-pack/plugin/sql/qa/server/security/src/test/java/org/elasticsearch/xpack/sql/qa/security/RestSqlSecurityAsyncIT.java
new file mode 100644
index 0000000000000..deeac8212161d
--- /dev/null
+++ b/x-pack/plugin/sql/qa/server/security/src/test/java/org/elasticsearch/xpack/sql/qa/security/RestSqlSecurityAsyncIT.java
@@ -0,0 +1,202 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.sql.qa.security;
+
+import org.apache.http.util.EntityUtils;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.RequestOptions;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.client.ResponseException;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.common.xcontent.json.JsonXContent;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.test.rest.ESRestTestCase;
+import org.elasticsearch.xpack.core.XPackPlugin;
+import org.elasticsearch.xpack.core.async.AsyncExecutionId;
+import org.elasticsearch.xpack.sql.qa.rest.BaseRestSqlTestCase;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.Map;
+
+import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
+import static org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField.RUN_AS_USER_HEADER;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.equalTo;
+
+public class RestSqlSecurityAsyncIT extends ESRestTestCase {
+
+    @Before
+    public void indexDocuments() throws IOException {
+        createIndex("index", Settings.EMPTY);
+        index("index", "0", "event_type", "my_event", "@timestamp", "2020-04-09T12:35:48Z", "val", 0);
+        refresh("index");
+
+        createIndex("index-user1", Settings.EMPTY);
+        index("index-user1", "0", "event_type", "my_event", "@timestamp", "2020-04-09T12:35:48Z", "val", 0);
+        refresh("index-user1");
+
+        createIndex("index-user2", Settings.EMPTY);
+        index("index-user2", "0", "event_type", "my_event", "@timestamp", "2020-04-09T12:35:48Z", "val", 0);
+        refresh("index-user2");
+    }
+
+    @Override
+    protected Settings restClientSettings() {
+        return RestSqlIT.securitySettings();
+    }
+
+    @Override
+    protected String getProtocol() {
+        return RestSqlIT.SSL_ENABLED ? "https" : "http";
+    }
+
+    public void testWithUsers() throws Exception {
+        testCase("user1", "user2");
+        testCase("user2", "user1");
+    }
+
+    private void testCase(String user, String otherUser) throws Exception {
+        for (String indexName : new String[] { "index", "index-" + user }) {
+            Response submitResp = submitAsyncSqlSearch(
+                "SELECT event_type FROM \"" + indexName + "\" WHERE val=0",
+                TimeValue.timeValueSeconds(10),
+                user
+            );
+            assertOK(submitResp);
+            String id = extractResponseId(submitResp);
+            Response getResp = getAsyncSqlSearch(id, user);
+            assertOK(getResp);
+
+            // other cannot access the result
+            ResponseException exc = expectThrows(ResponseException.class, () -> getAsyncSqlSearch(id, otherUser));
+            assertThat(exc.getResponse().getStatusLine().getStatusCode(), equalTo(404));
+
+            // other cannot delete the result
+            exc = expectThrows(ResponseException.class, () -> deleteAsyncSqlSearch(id, otherUser));
+            assertThat(exc.getResponse().getStatusLine().getStatusCode(), equalTo(404));
+
+            // other and user cannot access the result from direct get calls
+            AsyncExecutionId searchId = AsyncExecutionId.decode(id);
+            for (String runAs : new String[] { user, otherUser }) {
+                exc = expectThrows(ResponseException.class, () -> get(XPackPlugin.ASYNC_RESULTS_INDEX, searchId.getDocId(), runAs));
+                assertThat(exc.getResponse().getStatusLine().getStatusCode(), equalTo(403));
+                assertThat(exc.getMessage(), containsString("unauthorized"));
+            }
+
+            Response delResp = deleteAsyncSqlSearch(id, user);
+            assertOK(delResp);
+        }
+        ResponseException exc = expectThrows(
+            ResponseException.class,
+            () -> submitAsyncSqlSearch("SELECT * FROM \"index-" + otherUser + "\"", TimeValue.timeValueSeconds(10), user)
+        );
+        assertThat(exc.getResponse().getStatusLine().getStatusCode(), equalTo(400));
+    }
+
+    // user with manage privilege can check status and delete
+    public void testWithManager() throws IOException {
+        Response submitResp = submitAsyncSqlSearch("SELECT event_type FROM \"index\" WHERE val=0", TimeValue.timeValueSeconds(10), "user1");
+        assertOK(submitResp);
+        String id = extractResponseId(submitResp);
+        Response getResp = getAsyncSqlSearch(id, "user1");
+        assertOK(getResp);
+
+        Response getStatus = getAsyncSqlStatus(id, "manage_user");
+        assertOK(getStatus);
+        Map<String, Object> status = BaseRestSqlTestCase.toMap(getStatus, null);
+        assertEquals(200, status.get("completion_status"));
+
+        Response deleteResp = deleteAsyncSqlSearch(id, "manage_user");
+        assertOK(deleteResp);
+    }
+
+    static String extractResponseId(Response response) throws IOException {
+        Map<String, Object> map = toMap(response);
+        return (String) map.get("id");
+    }
+
+    static void index(String index, String id, Object... fields) throws IOException {
+        XContentBuilder document = jsonBuilder().startObject();
+        for (int i = 0; i < fields.length; i += 2) {
+            document.field((String) fields[i], fields[i + 1]);
+        }
+        document.endObject();
+        final Request request = new Request("POST", "/" + index + "/_doc/" + id);
+        request.setJsonEntity(Strings.toString(document));
+        assertOK(client().performRequest(request));
+    }
+
+    static void refresh(String index) throws IOException {
+        assertOK(adminClient().performRequest(new Request("POST", "/" + index + "/_refresh")));
+    }
+
+    static Response get(String index, String id, String user) throws IOException {
+        final Request request = new Request("GET", "/" + index + "/_doc/" + id);
+        setRunAsHeader(request, user);
+        return client().performRequest(request);
+    }
+
+    static Response submitAsyncSqlSearch(String query, TimeValue waitForCompletion, String user) throws IOException {
+        final Request request = new Request("POST", "/_sql");
+        setRunAsHeader(request, user);
+        request.setJsonEntity(
+            Strings.toString(
+                JsonXContent.contentBuilder()
+                    .startObject()
+                    .field("query", query)
+                    .field("wait_for_completion_timeout", waitForCompletion.toString())
+                    // we do the cleanup explicitly
+                    .field("keep_on_completion", "true")
+                    .endObject()
+            )
+        );
+        return client().performRequest(request);
+    }
+
+    static Response getAsyncSqlSearch(String id, String user) throws IOException {
+        final Request request = new Request("GET", "/_sql/async/" + id);
+        setRunAsHeader(request, user);
+        request.addParameter("wait_for_completion_timeout", "0ms");
+        request.addParameter("format", "json");
+        return client().performRequest(request);
+    }
+
+    static Response getAsyncSqlStatus(String id, String user) throws IOException {
+        final Request request = new Request("GET", "/_sql/async/status/" + id);
+        setRunAsHeader(request, user);
+        request.addParameter("format", "json");
+        return client().performRequest(request);
+    }
+
+    static Response deleteAsyncSqlSearch(String id, String user) throws IOException {
+        final Request request = new Request("DELETE", "/_sql/async/delete/" + id);
+        setRunAsHeader(request, user);
+        return client().performRequest(request);
+    }
+
+    static Map<String, Object> toMap(Response response) throws IOException {
+        return toMap(EntityUtils.toString(response.getEntity()));
+    }
+
+    static Map<String, Object> toMap(String response) {
+        return XContentHelper.convertToMap(JsonXContent.jsonXContent, response, false);
+    }
+
+    /**
+     * Use es-security-runas-user to become a less privileged user.
+     */
+    static void setRunAsHeader(Request request, String user) {
+        final RequestOptions.Builder builder = RequestOptions.DEFAULT.toBuilder();
+        builder.addHeader(RUN_AS_USER_HEADER, user);
+        request.setOptions(builder);
+    }
+}
diff --git a/x-pack/plugin/sql/qa/server/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/RestSqlIT.java b/x-pack/plugin/sql/qa/server/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/RestSqlIT.java
index aad2c7dffb0f2..c0a1a79e4c9a7 100644
--- a/x-pack/plugin/sql/qa/server/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/RestSqlIT.java
+++ b/x-pack/plugin/sql/qa/server/single-node/src/test/java/org/elasticsearch/xpack/sql/qa/single_node/RestSqlIT.java
@@ -85,7 +85,7 @@ public void testIncorrectAcceptHeader() throws IOException {
         request.setEntity(stringEntity);
         expectBadRequest(
             () -> toMap(client().performRequest(request), "plain"),
-            containsString("Invalid response content type: Accept=[application/fff]")
+            containsString("Invalid request content type: Accept=[application/fff]")
         );
     }
 }
diff --git a/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/BaseRestSqlTestCase.java b/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/BaseRestSqlTestCase.java
index fdcb6f936b955..b40bfe2a2806f 100644
--- a/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/BaseRestSqlTestCase.java
+++ b/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/BaseRestSqlTestCase.java
@@ -24,17 +24,20 @@
 import static org.elasticsearch.xpack.sql.proto.Protocol.BINARY_FORMAT_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.CLIENT_ID_NAME;
-import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNAR_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.CURSOR_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.FETCH_SIZE_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.FIELD_MULTI_VALUE_LENIENCY_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.FILTER_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ALIVE_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ON_COMPLETION_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.QUERY_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.RUNTIME_MAPPINGS_NAME;
 import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.WAIT_FOR_COMPLETION_TIMEOUT_NAME;
 
 public abstract class BaseRestSqlTestCase extends ESRestTestCase {
 
@@ -107,6 +110,21 @@ public RequestObjectBuilder binaryFormat(Boolean binaryFormat) {
             return this;
         }
 
+        public RequestObjectBuilder waitForCompletionTimeout(String timeout) {
+            request.append(field(WAIT_FOR_COMPLETION_TIMEOUT_NAME, timeout));
+            return this;
+        }
+
+        public RequestObjectBuilder keepOnCompletion(Boolean keepOnCompletion) {
+            request.append(field(KEEP_ON_COMPLETION_NAME, keepOnCompletion));
+            return this;
+        }
+
+        public RequestObjectBuilder keepAlive(String keepAlive) {
+            request.append(field(KEEP_ALIVE_NAME, keepAlive));
+            return this;
+        }
+
         public RequestObjectBuilder fieldMultiValueLeniency(Boolean fieldMultiValueLeniency) {
             request.append(field(FIELD_MULTI_VALUE_LENIENCY_NAME, fieldMultiValueLeniency));
             return this;
diff --git a/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java b/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java
index ab109d4ca6011..13906046e7af4 100644
--- a/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java
+++ b/x-pack/plugin/sql/qa/server/src/main/java/org/elasticsearch/xpack/sql/qa/rest/RestSqlTestCase.java
@@ -18,6 +18,7 @@
 import org.elasticsearch.common.CheckedSupplier;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Tuple;
 import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -31,6 +32,7 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.InputStreamReader;
+import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
 import java.sql.JDBCType;
 import java.time.Instant;
@@ -49,7 +51,23 @@
 import static java.util.Collections.singletonList;
 import static java.util.Collections.singletonMap;
 import static java.util.Collections.unmodifiableMap;
+import static org.elasticsearch.common.Strings.hasText;
 import static org.elasticsearch.xpack.ql.TestUtils.getNumberOfSearchContexts;
+import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNS_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_ASYNC_ID;
+import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_ASYNC_PARTIAL;
+import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_ASYNC_RUNNING;
+import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_CURSOR;
+import static org.elasticsearch.xpack.sql.proto.Protocol.ID_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.IS_PARTIAL_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.IS_RUNNING_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.ROWS_NAME;
+import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_ASYNC_DELETE_REST_ENDPOINT;
+import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_ASYNC_REST_ENDPOINT;
+import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_ASYNC_STATUS_REST_ENDPOINT;
+import static org.elasticsearch.xpack.sql.proto.Protocol.URL_PARAM_DELIMITER;
+import static org.elasticsearch.xpack.sql.proto.Protocol.URL_PARAM_FORMAT;
+import static org.elasticsearch.xpack.sql.proto.Protocol.WAIT_FOR_COMPLETION_TIMEOUT_NAME;
 import static org.hamcrest.Matchers.containsString;
 
 /**
@@ -91,17 +109,10 @@ public void testBasicQuery() throws IOException {
     }
 
     public void testNextPage() throws IOException {
-        Request request = new Request("POST", "/test/_bulk");
-        request.addParameter("refresh", "true");
-        String mode = randomMode();
-        StringBuilder bulk = new StringBuilder();
-        for (int i = 0; i < 20; i++) {
-            bulk.append("{\"index\":{\"_id\":\"" + i + "\"}}\n");
-            bulk.append("{\"text\":\"text" + i + "\", \"number\":" + i + "}\n");
-        }
-        request.setJsonEntity(bulk.toString());
-        client().performRequest(request);
+        final int count = 20;
+        bulkLoadTestData(count);
 
+        String mode = randomMode();
         boolean columnar = randomBoolean();
         String sqlRequest = query("SELECT text, number, SQRT(number) AS s, SCORE()" + " FROM test" + " ORDER BY number, SCORE()").mode(
             mode
@@ -109,7 +120,7 @@ public void testNextPage() throws IOException {
         Number value = xContentDependentFloatingNumberValue(mode, 1f);
 
         String cursor = null;
-        for (int i = 0; i < 20; i += 2) {
+        for (int i = 0; i < count; i += 2) {
             Map<String, Object> response;
             if (i == 0) {
                 response = runSql(new StringEntity(sqlRequest, ContentType.APPLICATION_JSON), "", mode);
@@ -964,7 +975,7 @@ public void testDefaultQueryInCSV() throws IOException {
         Tuple<String, String> response = runSqlAsText(query, "text/csv");
         assertEquals(expected, response.v1());
 
-        response = runSqlAsTextFormat(query, "csv");
+        response = runSqlAsTextWithFormat(query, "csv");
         assertEquals(expected, response.v1());
     }
 
@@ -1027,7 +1038,7 @@ public void testQueryInTSV() throws IOException {
         String query = "SELECT * FROM test ORDER BY number";
         Tuple<String, String> response = runSqlAsText(query, "text/tab-separated-values");
         assertEquals(expected, response.v1());
-        response = runSqlAsTextFormat(query, "tsv");
+        response = runSqlAsTextWithFormat(query, "tsv");
         assertEquals(expected, response.v1());
     }
 
@@ -1137,7 +1148,19 @@ private void executeQueryWithNextPage(String format, String expectedHeader, Stri
         assertEquals(0, getNumberOfSearchContexts(client(), "test"));
     }
 
-    private Tuple<String, String> runSqlAsText(String sql, String accept) throws IOException {
+    private static void bulkLoadTestData(int count) throws IOException {
+        Request request = new Request("POST", "/test/_bulk");
+        request.addParameter("refresh", "true");
+        StringBuilder bulk = new StringBuilder();
+        for (int i = 0; i < count; i++) {
+            bulk.append("{\"index\":{\"_id\":\"" + i + "\"}}\n");
+            bulk.append("{\"text\":\"text" + i + "\", \"number\":" + i + "}\n");
+        }
+        request.setJsonEntity(bulk.toString());
+        client().performRequest(request);
+    }
+
+    private static Tuple<String, String> runSqlAsText(String sql, String accept) throws IOException {
         return runSqlAsText(StringUtils.EMPTY, new StringEntity(query(sql).toString(), ContentType.APPLICATION_JSON), accept);
     }
 
@@ -1145,7 +1168,7 @@ private Tuple runSqlAsText(String sql, String accept) throws IOE
      * Run SQL as text using the {@code Accept} header to specify the format
      * rather than the {@code format} parameter.
     */
-    private Tuple<String, String> runSqlAsText(String suffix, HttpEntity entity, String accept) throws IOException {
+    private static Tuple<String, String> runSqlAsText(String suffix, HttpEntity entity, String accept) throws IOException {
         Request request = new Request("POST", SQL_QUERY_REST_ENDPOINT + suffix);
         request.addParameter("error_trace", "true");
         request.setEntity(entity);
@@ -1153,27 +1176,25 @@ private Tuple runSqlAsText(String suffix, HttpEntity entity, Str
         options.addHeader("Accept", accept);
         request.setOptions(options);
         Response response = client().performRequest(request);
-        return new Tuple<>(
-            Streams.copyToString(new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8)),
-            response.getHeader("Cursor")
-        );
+        return new Tuple<>(responseBody(response), response.getHeader("Cursor"));
+    }
+
+    private static String responseBody(Response response) throws IOException {
+        return Streams.copyToString(new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8));
     }
 
     /**
      * Run SQL as text using the {@code format} parameter to specify the format
     * rather than an {@code Accept} header.
*/ - private Tuple runSqlAsTextFormat(String sql, String format) throws IOException { + private static Tuple runSqlAsTextWithFormat(String sql, String format) throws IOException { Request request = new Request("POST", SQL_QUERY_REST_ENDPOINT); request.addParameter("error_trace", "true"); request.addParameter("format", format); request.setJsonEntity(query(sql).toString()); Response response = client().performRequest(request); - return new Tuple<>( - Streams.copyToString(new InputStreamReader(response.getEntity().getContent(), StandardCharsets.UTF_8)), - response.getHeader("Cursor") - ); + return new Tuple<>(responseBody(response), response.getHeader("Cursor")); } public static void assertResponse(Map expected, Map actual) { @@ -1183,4 +1204,207 @@ public static void assertResponse(Map expected, Map expected = new HashMap<>(); + expected.put(IS_PARTIAL_NAME, false); + expected.put(IS_RUNNING_NAME, false); + expected.put(COLUMNS_NAME, singletonList(columnInfo(mode, "1", "integer", JDBCType.INTEGER, 11))); + expected.put(ROWS_NAME, singletonList(singletonList(1))); + assertAsyncResponse(expected, runSql(builder, mode)); + } + + public void testAsyncTextWait() throws IOException { + RequestObjectBuilder builder = query("SELECT 1").waitForCompletionTimeout("1d").keepOnCompletion(false); + + Map contentMap = new HashMap<>() { + { + put("txt", " 1 \n---------------\n1 \n"); + put("csv", "1\r\n1\r\n"); + put("tsv", "1\n1\n"); + } + }; + + for (String format : contentMap.keySet()) { + Response response = runSqlAsTextWithFormat(builder, format); + + assertEquals(contentMap.get(format), responseBody(response)); + + assertTrue(hasText(response.getHeader(HEADER_NAME_ASYNC_ID))); + assertEquals("false", response.getHeader(HEADER_NAME_ASYNC_PARTIAL)); + assertEquals("false", response.getHeader(HEADER_NAME_ASYNC_RUNNING)); + } + } + + public void testAsyncTextPaginated() throws IOException, InterruptedException { + final Map acceptMap = new HashMap<>() { + { + put("txt", "text/plain"); + put("csv", "text/csv"); + put("tsv", "text/tab-separated-values"); + } + }; + final int fetchSize = randomIntBetween(1, 10); + final int fetchCount = randomIntBetween(1, 9); + bulkLoadTestData(fetchSize * fetchCount); // NB: product needs to stay below 100, for txt format tests + + String format = randomFrom(acceptMap.keySet()); + String mode = randomMode(); + String cursor = null; + for (int i = 0; i <= fetchCount; i++) { // the last iteration (the equality in `<=`) checks on no-cursor & no-results + // start the query + RequestObjectBuilder builder = (hasText(cursor) ? 
cursor(cursor) : query("SELECT text, number FROM test")).fetchSize(fetchSize) + .waitForCompletionTimeout("0d") // don't wait at all + .keepOnCompletion(true) + .keepAlive("1d") // keep "forever" + .mode(mode) + .binaryFormat(false); // prevent JDBC mode to (ignore `format` and) enforce CBOR + Response response = runSqlAsTextWithFormat(builder, format); + + Character csvDelimiter = ','; + + assertEquals(200, response.getStatusLine().getStatusCode()); + assertTrue(response.getHeader(HEADER_NAME_ASYNC_PARTIAL).equals(response.getHeader(HEADER_NAME_ASYNC_RUNNING))); + String asyncId = response.getHeader(HEADER_NAME_ASYNC_ID); + assertTrue(hasText(asyncId)); + + // it happens though rarely that the whole response comes through despite the given 0-wait + if (response.getHeader(HEADER_NAME_ASYNC_PARTIAL).equals("true")) { + + // potentially wait for it to complete + boolean pollForCompletion = randomBoolean(); + if (pollForCompletion) { + Request request = new Request("GET", SQL_ASYNC_STATUS_REST_ENDPOINT + asyncId); + Map asyncStatus = null; + long millis = 1; + for (boolean isRunning = true; isRunning; Thread.sleep(millis *= 2)) { + asyncStatus = toMap(client().performRequest(request), null); + isRunning = (boolean) asyncStatus.get(IS_RUNNING_NAME); + } + assertEquals(200, (int) asyncStatus.get("completion_status")); + assertEquals(asyncStatus.get(IS_RUNNING_NAME), asyncStatus.get(IS_PARTIAL_NAME)); + assertEquals(asyncId, asyncStatus.get(ID_NAME)); + } + + // fetch the results (potentially waiting now to complete) + Request request = new Request("GET", SQL_ASYNC_REST_ENDPOINT + asyncId); + if (pollForCompletion == false) { + request.addParameter(WAIT_FOR_COMPLETION_TIMEOUT_NAME, "1d"); + } + if (randomBoolean()) { + request.addParameter(URL_PARAM_FORMAT, format); + if (format.equals("csv")) { + csvDelimiter = ';'; + request.addParameter(URL_PARAM_DELIMITER, URLEncoder.encode(String.valueOf(csvDelimiter), StandardCharsets.UTF_8)); + } + } else { + request.setOptions(request.getOptions().toBuilder().addHeader("Accept", acceptMap.get(format))); + } + response = client().performRequest(request); + + assertEquals(200, response.getStatusLine().getStatusCode()); + assertEquals(asyncId, response.getHeader(HEADER_NAME_ASYNC_ID)); + assertEquals("false", response.getHeader(HEADER_NAME_ASYNC_PARTIAL)); + assertEquals("false", response.getHeader(HEADER_NAME_ASYNC_RUNNING)); + } + + cursor = response.getHeader(HEADER_NAME_CURSOR); + String body = responseBody(response); + if (i == fetchCount) { + assertNull(cursor); + assertFalse(hasText(body)); + } else { + String expected = expectedTextBody(format, fetchSize, i, csvDelimiter); + assertEquals(expected, body); + + if (hasText(cursor) == false) { // depending on index and fetch size, the last page might or not have a cursor + assertEquals(i, fetchCount - 1); + i++; // end the loop after deleting the async resources + } + } + + // delete the query results + Request request = new Request("DELETE", SQL_ASYNC_DELETE_REST_ENDPOINT + asyncId); + Map deleteStatus = toMap(client().performRequest(request), null); + assertEquals(200, response.getStatusLine().getStatusCode()); + assertTrue((boolean) deleteStatus.get("acknowledged")); + } + } + + static Map runSql(RequestObjectBuilder builder, String mode) throws IOException { + return toMap(runSql(builder.mode(mode)), mode); + } + + static Response runSql(RequestObjectBuilder builder) throws IOException { + return runSqlAsTextWithFormat(builder, null); + } + + static Response 
runSqlAsTextWithFormat(RequestObjectBuilder builder, @Nullable String format) throws IOException { + Request request = new Request("POST", SQL_QUERY_REST_ENDPOINT); + request.addParameter("error_trace", "true"); // Helps with debugging in case something crazy happens on the server. + request.addParameter("pretty", "true"); // Improves error reporting readability + if (format != null) { + request.addParameter(URL_PARAM_FORMAT, format); // Improves error reporting readability + } + request.setEntity(new StringEntity(builder.toString(), ContentType.APPLICATION_JSON)); + return client().performRequest(request); + + } + + static void assertAsyncResponse(Map expected, Map actual) { + String actualId = (String) actual.get("id"); + assertTrue("async ID missing in response", hasText(actualId)); + expected.put("id", actualId); + assertResponse(expected, actual); + } + + private static String expectedTextBody(String format, int fetchSize, int count, Character csvDelimiter) { + StringBuilder sb = new StringBuilder(); + if (count == 0) { // add the header + switch (format) { + case "txt": + sb.append(" text | number \n"); + sb.append("---------------+---------------\n"); + break; + case "csv": + sb.append("text").append(csvDelimiter).append("number\r\n"); + break; + case "tsv": + sb.append("text\tnumber\n"); + break; + default: + assert false : "unexpected format type [" + format + "]"; + } + } + for (int i = 0; i < fetchSize; i++) { + int val = fetchSize * count + i; + sb.append("text").append(val); + switch (format) { + case "txt": + sb.append(val < 10 ? " " : StringUtils.EMPTY).append(" |"); + break; + case "csv": + sb.append(csvDelimiter); + break; + case "tsv": + sb.append('\t'); + break; + } + sb.append(val); + if (format.equals("txt")) { + sb.append(" ").append(val < 10 ? " " : StringUtils.EMPTY); + } + sb.append(format.equals("csv") ? 
"\r\n" : "\n"); + } + return sb.toString(); + } } diff --git a/x-pack/plugin/sql/sql-action/build.gradle b/x-pack/plugin/sql/sql-action/build.gradle index f9db4e7026114..42136ee3e1d91 100644 --- a/x-pack/plugin/sql/sql-action/build.gradle +++ b/x-pack/plugin/sql/sql-action/build.gradle @@ -18,6 +18,8 @@ dependencies { api(project(':libs:elasticsearch-x-content')) { transitive = false } + api xpackProject('plugin:core') + api xpackProject('plugin:ql') api xpackProject('plugin:sql:sql-proto') api "org.apache.lucene:lucene-core:${versions.lucene}" api "joda-time:joda-time:${versions.joda}" @@ -137,4 +139,4 @@ tasks.named("thirdPartyAudit").configure { 'org.zeromq.ZMQ$Socket', 'org.zeromq.ZMQ' ) -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequest.java b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequest.java index 6cbcb1760b4cb..06e72ee1c903b 100644 --- a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequest.java +++ b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequest.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.sql.action; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.xcontent.ParseField; import org.elasticsearch.common.Strings; @@ -16,6 +17,8 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xpack.sql.proto.Protocol; import org.elasticsearch.xpack.sql.proto.RequestInfo; import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue; @@ -29,8 +32,15 @@ import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.xpack.sql.proto.Protocol.BINARY_FORMAT_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNAR_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.DEFAULT_KEEP_ALIVE; +import static org.elasticsearch.xpack.sql.proto.Protocol.DEFAULT_KEEP_ON_COMPLETION; +import static org.elasticsearch.xpack.sql.proto.Protocol.DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT; import static org.elasticsearch.xpack.sql.proto.Protocol.FIELD_MULTI_VALUE_LENIENCY_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.INDEX_INCLUDE_FROZEN_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ALIVE_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ON_COMPLETION_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.MIN_KEEP_ALIVE; +import static org.elasticsearch.xpack.sql.proto.Protocol.WAIT_FOR_COMPLETION_TIMEOUT_NAME; /** * Request to perform an sql query @@ -41,6 +51,9 @@ public class SqlQueryRequest extends AbstractSqlQueryRequest { static final ParseField FIELD_MULTI_VALUE_LENIENCY = new ParseField(FIELD_MULTI_VALUE_LENIENCY_NAME); static final ParseField INDEX_INCLUDE_FROZEN = new ParseField(INDEX_INCLUDE_FROZEN_NAME); static final ParseField BINARY_COMMUNICATION = new ParseField(BINARY_FORMAT_NAME); + static final ParseField WAIT_FOR_COMPLETION_TIMEOUT = new ParseField(WAIT_FOR_COMPLETION_TIMEOUT_NAME); + static final ParseField KEEP_ON_COMPLETION = new ParseField(KEEP_ON_COMPLETION_NAME); + static final ParseField KEEP_ALIVE = new ParseField(KEEP_ALIVE_NAME); static { 
PARSER.declareString(SqlQueryRequest::cursor, CURSOR); @@ -48,6 +61,12 @@ public class SqlQueryRequest extends AbstractSqlQueryRequest { PARSER.declareBoolean(SqlQueryRequest::fieldMultiValueLeniency, FIELD_MULTI_VALUE_LENIENCY); PARSER.declareBoolean(SqlQueryRequest::indexIncludeFrozen, INDEX_INCLUDE_FROZEN); PARSER.declareBoolean(SqlQueryRequest::binaryCommunication, BINARY_COMMUNICATION); + PARSER.declareField(SqlQueryRequest::waitForCompletionTimeout, + (p, c) -> TimeValue.parseTimeValue(p.text(), WAIT_FOR_COMPLETION_TIMEOUT_NAME), WAIT_FOR_COMPLETION_TIMEOUT, + ObjectParser.ValueType.VALUE); + PARSER.declareBoolean(SqlQueryRequest::keepOnCompletion, KEEP_ON_COMPLETION); + PARSER.declareField(SqlQueryRequest::keepAlive, + (p, c) -> TimeValue.parseTimeValue(p.text(), KEEP_ALIVE_NAME), KEEP_ALIVE, ObjectParser.ValueType.VALUE); } private String cursor = ""; @@ -62,24 +81,33 @@ public class SqlQueryRequest extends AbstractSqlQueryRequest { private boolean fieldMultiValueLeniency = Protocol.FIELD_MULTI_VALUE_LENIENCY; private boolean indexIncludeFrozen = Protocol.INDEX_INCLUDE_FROZEN; + // Async settings + private TimeValue waitForCompletionTimeout = DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT; + private boolean keepOnCompletion = DEFAULT_KEEP_ON_COMPLETION; + private TimeValue keepAlive = DEFAULT_KEEP_ALIVE; + public SqlQueryRequest() { super(); } public SqlQueryRequest(String query, List params, QueryBuilder filter, Map runtimeMappings, ZoneId zoneId, int fetchSize, TimeValue requestTimeout, TimeValue pageTimeout, Boolean columnar, - String cursor, RequestInfo requestInfo, boolean fieldMultiValueLeniency, boolean indexIncludeFrozen) { + String cursor, RequestInfo requestInfo, boolean fieldMultiValueLeniency, boolean indexIncludeFrozen, + TimeValue waitForCompletionTimeout, boolean keepOnCompletion, TimeValue keepAlive) { super(query, params, filter, runtimeMappings, zoneId, fetchSize, requestTimeout, pageTimeout, requestInfo); this.cursor = cursor; this.columnar = columnar; this.fieldMultiValueLeniency = fieldMultiValueLeniency; this.indexIncludeFrozen = indexIncludeFrozen; + this.waitForCompletionTimeout = waitForCompletionTimeout; + this.keepOnCompletion = keepOnCompletion; + this.keepAlive = keepAlive; } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = super.validate(); - if ((false == Strings.hasText(query())) && Strings.hasText(cursor) == false) { + if (Strings.hasText(query()) == false && Strings.hasText(cursor) == false) { validationException = addValidationError("one of [query] or [cursor] is required", validationException); } return validationException; @@ -146,6 +174,42 @@ public Boolean binaryCommunication() { return binaryCommunication; } + public SqlQueryRequest waitForCompletionTimeout(TimeValue waitForCompletionTimeout) { + this.waitForCompletionTimeout = waitForCompletionTimeout; + return this; + } + + public TimeValue waitForCompletionTimeout() { + return waitForCompletionTimeout; + } + + public SqlQueryRequest keepOnCompletion(boolean keepOnCompletion) { + this.keepOnCompletion = keepOnCompletion; + return this; + } + + public boolean keepOnCompletion() { + return keepOnCompletion; + } + + public SqlQueryRequest keepAlive(TimeValue keepAlive) { + if (keepAlive != null && keepAlive.getMillis() < MIN_KEEP_ALIVE.getMillis()) { + throw new IllegalArgumentException("[" + KEEP_ALIVE_NAME + "] must be greater than " + MIN_KEEP_ALIVE + ", got: " + keepAlive); + } + this.keepAlive = keepAlive; + return this; + } + + 
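Aside, not part of the change set: the three request-body fields parsed and validated above are what the REST layer accepts for async execution. Below is a minimal sketch of a submit call, written as a hypothetical test method in the same RestSqlTestCase/ESRestTestCase context shown earlier in this patch (so client(), Request and Response are in scope); the index name "test", the SELECT statement and the timeout values are illustrative assumptions, not defaults from the change set.

// Hypothetical helper, for illustration only: submit an async SQL query using the new body fields.
public void sketchSubmitAsyncQuery() throws Exception {
    Request request = new Request("POST", "/_sql");
    request.addParameter("format", "json");
    request.setJsonEntity(
        "{"
            + "\"query\": \"SELECT text, number FROM test ORDER BY number\","
            + "\"wait_for_completion_timeout\": \"0s\","   // don't block: return an async id right away
            + "\"keep_on_completion\": true,"              // store the result even if the query finishes in time
            + "\"keep_alive\": \"2d\""                     // example value; must be at least 1m (MIN_KEEP_ALIVE) or keepAlive() rejects it
            + "}"
    );
    Response response = client().performRequest(request);
    String asyncId = response.getHeader("Async-ID");       // also returned as "id" in the JSON body
}

These are the same fields that the RequestObjectBuilder additions in BaseRestSqlTestCase (waitForCompletionTimeout, keepOnCompletion, keepAlive) serialize, and the Async-ID/Async-partial/Async-running headers checked by testAsyncTextWait are the text-format counterparts of the id/is_partial/is_running fields in the JSON response.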
public TimeValue keepAlive() { + return keepAlive; + } + + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SqlQueryTask(id, type, action, getDescription(), parentTaskId, headers, null, null, keepAlive, + mode(), version(), columnar()); + } + public SqlQueryRequest(StreamInput in) throws IOException { super(in); cursor = in.readString(); @@ -153,6 +217,11 @@ public SqlQueryRequest(StreamInput in) throws IOException { fieldMultiValueLeniency = in.readBoolean(); indexIncludeFrozen = in.readBoolean(); binaryCommunication = in.readOptionalBoolean(); + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO: V_7_14_0 + this.waitForCompletionTimeout = in.readOptionalTimeValue(); + this.keepOnCompletion = in.readBoolean(); + this.keepAlive = in.readOptionalTimeValue(); + } } @Override @@ -163,11 +232,17 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(fieldMultiValueLeniency); out.writeBoolean(indexIncludeFrozen); out.writeOptionalBoolean(binaryCommunication); + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { // TODO: V_7_14_0 + out.writeOptionalTimeValue(waitForCompletionTimeout); + out.writeBoolean(keepOnCompletion); + out.writeOptionalTimeValue(keepAlive); + } } @Override public int hashCode() { - return Objects.hash(super.hashCode(), cursor, columnar, fieldMultiValueLeniency, indexIncludeFrozen, binaryCommunication); + return Objects.hash(super.hashCode(), cursor, columnar, fieldMultiValueLeniency, indexIncludeFrozen, binaryCommunication, + waitForCompletionTimeout, keepOnCompletion, keepAlive); } @Override @@ -177,7 +252,10 @@ public boolean equals(Object obj) { && Objects.equals(columnar, ((SqlQueryRequest) obj).columnar) && fieldMultiValueLeniency == ((SqlQueryRequest) obj).fieldMultiValueLeniency && indexIncludeFrozen == ((SqlQueryRequest) obj).indexIncludeFrozen - && binaryCommunication == ((SqlQueryRequest) obj).binaryCommunication; + && binaryCommunication == ((SqlQueryRequest) obj).binaryCommunication + && Objects.equals(waitForCompletionTimeout, ((SqlQueryRequest) obj).waitForCompletionTimeout) + && keepOnCompletion == ((SqlQueryRequest) obj).keepOnCompletion + && Objects.equals(keepAlive, ((SqlQueryRequest) obj).keepAlive); } @Override @@ -190,7 +268,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws // This is needed just to test round-trip compatibility with proto.SqlQueryRequest return new org.elasticsearch.xpack.sql.proto.SqlQueryRequest(query(), params(), zoneId(), fetchSize(), requestTimeout(), pageTimeout(), filter(), columnar(), cursor(), requestInfo(), fieldMultiValueLeniency(), indexIncludeFrozen(), - binaryCommunication(), runtimeMappings()).toXContent(builder, params); + binaryCommunication(), runtimeMappings(), waitForCompletionTimeout(), keepOnCompletion(), keepAlive()) + .toXContent(builder, params); } public static SqlQueryRequest fromXContent(XContentParser parser) { diff --git a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestBuilder.java b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestBuilder.java index c91a5aceedd7b..a1a0f296d3b9f 100644 --- a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestBuilder.java +++ b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestBuilder.java @@ -30,15 +30,18 @@ public class SqlQueryRequestBuilder extends 
ActionRequestBuilder params, QueryBuilder filter, Map runtimeMappings, ZoneId zoneId, int fetchSize, TimeValue requestTimeout, TimeValue pageTimeout, boolean columnar, String nextPageInfo, RequestInfo requestInfo, - boolean multiValueFieldLeniency, boolean indexIncludeFrozen) { + boolean multiValueFieldLeniency, boolean indexIncludeFrozen, TimeValue waitForCompletionTimeout, boolean keepOnCompletion, + TimeValue keepAlive) { super(client, action, new SqlQueryRequest(query, params, filter, runtimeMappings, zoneId, fetchSize, requestTimeout, pageTimeout, - columnar, nextPageInfo, requestInfo, multiValueFieldLeniency, indexIncludeFrozen)); + columnar, nextPageInfo, requestInfo, multiValueFieldLeniency, indexIncludeFrozen, waitForCompletionTimeout, + keepOnCompletion, keepAlive)); } public SqlQueryRequestBuilder query(String query) { @@ -105,4 +108,19 @@ public SqlQueryRequestBuilder multiValueFieldLeniency(boolean lenient) { request.fieldMultiValueLeniency(lenient); return this; } + + public SqlQueryRequestBuilder waitForCompletionTimeout(TimeValue waitForCompletionTimeout) { + request.waitForCompletionTimeout(waitForCompletionTimeout); + return this; + } + + public SqlQueryRequestBuilder keepOnCompletion(boolean keepOnCompletion) { + request.keepOnCompletion(keepOnCompletion); + return this; + } + + public SqlQueryRequestBuilder keepAlive(TimeValue keepAlive) { + request.keepAlive(keepAlive); + return this; + } } diff --git a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryResponse.java b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryResponse.java index 3685eec83dd9b..bfbd110688eb9 100644 --- a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryResponse.java +++ b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryResponse.java @@ -19,8 +19,10 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.ql.async.QlStatusResponse; import org.elasticsearch.xpack.sql.proto.ColumnInfo; import org.elasticsearch.xpack.sql.proto.Mode; +import org.elasticsearch.xpack.sql.proto.Protocol; import org.elasticsearch.xpack.sql.proto.SqlVersion; import org.elasticsearch.xpack.sql.proto.StringUtils; @@ -35,7 +37,7 @@ /** * Response to perform an sql query */ -public class SqlQueryResponse extends ActionResponse implements ToXContentObject { +public class SqlQueryResponse extends ActionResponse implements ToXContentObject, QlStatusResponse.AsyncStatus { // TODO: Simplify cursor handling private String cursor; @@ -46,6 +48,10 @@ public class SqlQueryResponse extends ActionResponse implements ToXContentObject // TODO investigate reusing Page here - it probably is much more efficient private List> rows; private static final String INTERVAL_CLASS_NAME = "Interval"; + // async + private final @Nullable String asyncExecutionId; + private final boolean isPartial; + private final boolean isRunning; public SqlQueryResponse(StreamInput in) throws IOException { super(in); @@ -75,6 +81,10 @@ public SqlQueryResponse(StreamInput in) throws IOException { } } this.rows = unmodifiableList(rows); + columnar = in.readBoolean(); + asyncExecutionId = in.readOptionalString(); + isPartial = in.readBoolean(); + isRunning = in.readBoolean(); } public SqlQueryResponse( @@ -83,7 +93,10 @@ public SqlQueryResponse( SqlVersion sqlVersion, boolean 
columnar, @Nullable List columns, - List> rows + List> rows, + @Nullable String asyncExecutionId, + boolean isPartial, + boolean isRunning ) { this.cursor = cursor; this.mode = mode; @@ -91,6 +104,20 @@ public SqlQueryResponse( this.columnar = columnar; this.columns = columns; this.rows = rows; + this.asyncExecutionId = asyncExecutionId; + this.isPartial = isPartial; + this.isRunning = isRunning; + } + + public SqlQueryResponse( + String cursor, + Mode mode, + SqlVersion sqlVersion, + boolean columnar, + @Nullable List columns, + List> rows + ) { + this(cursor, mode, sqlVersion, columnar, columns, rows, null, false, false); } /** @@ -157,12 +184,22 @@ public void writeTo(StreamOutput out) throws IOException { } } } + out.writeBoolean(columnar); + out.writeOptionalString(asyncExecutionId); + out.writeBoolean(isPartial); + out.writeBoolean(isRunning); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); { + if (hasId()) { + builder.field(Protocol.ID_NAME, asyncExecutionId); + builder.field(Protocol.IS_PARTIAL_NAME, isPartial); + builder.field(Protocol.IS_RUNNING_NAME, isRunning); + } + if (columns != null) { builder.startArray("columns"); { @@ -248,6 +285,25 @@ public static void writeColumnInfo(StreamOutput out, ColumnInfo columnInfo) thro out.writeOptionalVInt(columnInfo.displaySize()); } + public boolean hasId() { + return Strings.hasText(asyncExecutionId); + } + + @Override + public String id() { + return asyncExecutionId; + } + + @Override + public boolean isRunning() { + return isRunning; + } + + @Override + public boolean isPartial() { + return isPartial; + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryTask.java b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryTask.java new file mode 100644 index 0000000000000..710f5309a49aa --- /dev/null +++ b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlQueryTask.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.sql.action; + +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.core.async.AsyncExecutionId; +import org.elasticsearch.xpack.core.async.StoredAsyncTask; +import org.elasticsearch.xpack.sql.proto.Mode; +import org.elasticsearch.xpack.sql.proto.SqlVersion; + +import java.util.Map; + +import static java.util.Collections.emptyList; + +public class SqlQueryTask extends StoredAsyncTask { + + private final Mode mode; + private final SqlVersion sqlVersion; + private final boolean columnar; + + public SqlQueryTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers, + Map originHeaders, AsyncExecutionId asyncExecutionId, TimeValue keepAlive, Mode mode, + SqlVersion sqlVersion, boolean columnar) { + super(id, type, action, description, parentTaskId, headers, originHeaders, asyncExecutionId, keepAlive); + this.mode = mode; + this.sqlVersion = sqlVersion; + this.columnar = columnar; + } + + @Override + public SqlQueryResponse getCurrentResult() { + // for Ql searches we never store a search response in the task (neither partial, nor final) + // we kill the task on final response, so if the task is still present, it means the search is still running + // NB: the schema is only returned in the actual first (and currently last) response to the query + return new SqlQueryResponse("", mode, sqlVersion, columnar, null, emptyList(), + getExecutionId().getEncoded(), true, true); + } +} diff --git a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlTranslateRequest.java b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlTranslateRequest.java index fe47e6327c74c..bbd4cc5088931 100644 --- a/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlTranslateRequest.java +++ b/x-pack/plugin/sql/sql-action/src/main/java/org/elasticsearch/xpack/sql/action/SqlTranslateRequest.java @@ -75,6 +75,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws false, false, null, - runtimeMappings()).toXContent(builder, params); + runtimeMappings(), + null, + false, + null).toXContent(builder, params); } } diff --git a/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestTests.java b/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestTests.java index 9fa1e7ddfc133..5a475b05e42c7 100644 --- a/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestTests.java +++ b/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryRequestTests.java @@ -43,6 +43,9 @@ import static org.elasticsearch.xpack.sql.proto.Protocol.FIELD_MULTI_VALUE_LENIENCY_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.FILTER_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.INDEX_INCLUDE_FROZEN_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ALIVE_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ON_COMPLETION_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.MIN_KEEP_ALIVE; import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.PAGE_TIMEOUT_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_NAME; @@ -53,6 +56,7 @@ import static org.elasticsearch.xpack.sql.proto.Protocol.RUNTIME_MAPPINGS_NAME; 
import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.WAIT_FOR_COMPLETION_TIMEOUT_NAME; import static org.elasticsearch.xpack.sql.proto.RequestInfo.CLIENT_IDS; public class SqlQueryRequestTests extends AbstractWireSerializingTestCase { @@ -82,7 +86,7 @@ protected SqlQueryRequest createTestInstance() { return new SqlQueryRequest(randomAlphaOfLength(10), randomParameters(), SqlTestUtils.randomFilterOrNull(random()), randomRuntimeMappings(), randomZone(), between(1, Integer.MAX_VALUE), randomTV(), randomTV(), randomBoolean(), randomAlphaOfLength(10), requestInfo, - randomBoolean(), randomBoolean() + randomBoolean(), randomBoolean(), randomTV(), randomBoolean(), randomTVGreaterThan(MIN_KEEP_ALIVE) ); } @@ -103,12 +107,16 @@ protected SqlQueryRequest mutateInstance(SqlQueryRequest instance) { request -> request.requestTimeout(randomValueOtherThan(request.requestTimeout(), this::randomTV)), request -> request.filter(randomValueOtherThan(request.filter(), () -> request.filter() == null ? randomFilter(random()) : randomFilterOrNull(random()))), - request -> request.columnar(randomValueOtherThan(request.columnar(), () -> randomBoolean())), - request -> request.cursor(randomValueOtherThan(request.cursor(), SqlQueryResponseTests::randomStringCursor)) + request -> request.columnar(randomValueOtherThan(request.columnar(), ESTestCase::randomBoolean)), + request -> request.cursor(randomValueOtherThan(request.cursor(), SqlQueryResponseTests::randomStringCursor)), + request -> request.waitForCompletionTimeout(randomValueOtherThan(request.waitForCompletionTimeout(), this::randomTV)), + request -> request.keepOnCompletion(randomValueOtherThan(request.keepOnCompletion(), ESTestCase::randomBoolean)), + request -> request.keepAlive(randomValueOtherThan(request.keepAlive(), () -> randomTVGreaterThan(MIN_KEEP_ALIVE))) ); SqlQueryRequest newRequest = new SqlQueryRequest(instance.query(), instance.params(), instance.filter(), instance.runtimeMappings(), instance.zoneId(), instance.fetchSize(), instance.requestTimeout(), instance.pageTimeout(), instance.columnar(), - instance.cursor(), instance.requestInfo(), instance.fieldMultiValueLeniency(), instance.indexIncludeFrozen()); + instance.cursor(), instance.requestInfo(), instance.fieldMultiValueLeniency(), instance.indexIncludeFrozen(), + instance.waitForCompletionTimeout(), instance.keepOnCompletion(), instance.keepAlive()); mutator.accept(newRequest); return newRequest; } @@ -155,6 +163,14 @@ private TimeValue randomTV() { return TimeValue.parseTimeValue(randomTimeValue(), null, "test"); } + private TimeValue randomTVGreaterThan(TimeValue min) { + TimeValue value; + do { + value = randomTV(); + } while (value.getMillis() < min.getMillis()); + return value; + } + public List randomParameters() { if (randomBoolean()) { return Collections.emptyList(); @@ -247,6 +263,15 @@ private static void toXContent(SqlQueryRequest request, XContentBuilder builder) if (request.runtimeMappings() != null) { builder.field(RUNTIME_MAPPINGS_NAME, request.runtimeMappings()); } + if (request.waitForCompletionTimeout() != null) { + builder.field(WAIT_FOR_COMPLETION_TIMEOUT_NAME, request.waitForCompletionTimeout().getStringRep()); + } + if (request.keepOnCompletion()) { + builder.field(KEEP_ON_COMPLETION_NAME, request.keepOnCompletion()); + } + if (request.keepAlive() != null) { + builder.field(KEEP_ALIVE_NAME, 
request.keepAlive().getStringRep()); + } builder.endObject(); } } diff --git a/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryResponseTests.java b/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryResponseTests.java index d3e14f5a00a52..cabfdfebec2bf 100644 --- a/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryResponseTests.java +++ b/x-pack/plugin/sql/sql-action/src/test/java/org/elasticsearch/xpack/sql/action/SqlQueryResponseTests.java @@ -28,6 +28,9 @@ import static org.elasticsearch.common.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.sql.action.AbstractSqlQueryRequest.CURSOR; +import static org.elasticsearch.xpack.sql.proto.Protocol.ID_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.IS_PARTIAL_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.IS_RUNNING_NAME; import static org.elasticsearch.xpack.sql.proto.SqlVersion.DATE_NANOS_SUPPORT_VERSION; import static org.hamcrest.Matchers.hasSize; @@ -39,7 +42,8 @@ static String randomStringCursor() { @Override protected SqlQueryResponse createTestInstance() { - return createRandomInstance(randomStringCursor(), randomFrom(Mode.values()), randomBoolean()); + return createRandomInstance(randomStringCursor(), randomFrom(Mode.values()), randomBoolean(), + rarely() ? null : randomAlphaOfLength(100), randomBoolean(), randomBoolean()); } @Override @@ -47,7 +51,8 @@ protected Writeable.Reader instanceReader() { return SqlQueryResponse::new; } - public static SqlQueryResponse createRandomInstance(String cursor, Mode mode, boolean columnar) { + public static SqlQueryResponse createRandomInstance(String cursor, Mode mode, boolean columnar, String asyncExecutionId, + boolean isPartial, boolean isRunning) { int columnCount = between(1, 10); List columns = null; @@ -84,7 +89,7 @@ public static SqlQueryResponse createRandomInstance(String cursor, Mode mode, bo rows.add(row); } } - return new SqlQueryResponse(cursor, mode, DATE_NANOS_SUPPORT_VERSION, false, columns, rows); + return new SqlQueryResponse(cursor, mode, DATE_NANOS_SUPPORT_VERSION, false, columns, rows, asyncExecutionId, isPartial, isRunning); } public void testToXContent() throws IOException { @@ -125,12 +130,19 @@ public void testToXContent() throws IOException { if (testInstance.cursor().equals("") == false) { assertEquals(rootMap.get(CURSOR.getPreferredName()), testInstance.cursor()); } + + if (Strings.hasText(testInstance.id())) { + assertEquals(testInstance.id(), rootMap.get(ID_NAME)); + assertEquals(testInstance.isPartial(), rootMap.get(IS_PARTIAL_NAME)); + assertEquals(testInstance.isRunning(), rootMap.get(IS_RUNNING_NAME)); + } } @Override protected SqlQueryResponse doParseInstance(XContentParser parser) { org.elasticsearch.xpack.sql.proto.SqlQueryResponse response = org.elasticsearch.xpack.sql.proto.SqlQueryResponse.fromXContent(parser); - return new SqlQueryResponse(response.cursor(), Mode.JDBC, DATE_NANOS_SUPPORT_VERSION, false, response.columns(), response.rows()); + return new SqlQueryResponse(response.cursor(), Mode.JDBC, DATE_NANOS_SUPPORT_VERSION, false, response.columns(), response.rows(), + response.id(), response.isPartial(), response.isRunning()); } } diff --git a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/Protocol.java b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/Protocol.java index 1b73485ab3eec..b6ae31a9e5b2a 100644 --- 
a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/Protocol.java +++ b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/Protocol.java @@ -34,6 +34,11 @@ public final class Protocol { public static final String FIELD_MULTI_VALUE_LENIENCY_NAME = "field_multi_value_leniency"; public static final String INDEX_INCLUDE_FROZEN_NAME = "index_include_frozen"; public static final String RUNTIME_MAPPINGS_NAME = "runtime_mappings"; + // async + public static final String WAIT_FOR_COMPLETION_TIMEOUT_NAME = "wait_for_completion_timeout"; + public static final String KEEP_ON_COMPLETION_NAME = "keep_on_completion"; + public static final String KEEP_ALIVE_NAME = "keep_alive"; + // params public static final String PARAMS_NAME = "params"; public static final String PARAMS_TYPE_NAME = "type"; @@ -41,6 +46,10 @@ public final class Protocol { // responses public static final String COLUMNS_NAME = "columns"; public static final String ROWS_NAME = "rows"; + // responses async + public static final String ID_NAME = "id"; + public static final String IS_PARTIAL_NAME = "is_partial"; + public static final String IS_RUNNING_NAME = "is_running"; public static final ZoneId TIME_ZONE = ZoneId.of("Z"); @@ -61,12 +70,26 @@ public final class Protocol { public static final Boolean COLUMNAR = Boolean.FALSE; public static final Boolean BINARY_COMMUNICATION = null; + public static final TimeValue DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT = null; + public static final Boolean DEFAULT_KEEP_ON_COMPLETION = false; + public static TimeValue DEFAULT_KEEP_ALIVE = TimeValue.timeValueDays(5); + public static TimeValue MIN_KEEP_ALIVE = TimeValue.timeValueMinutes(1); + /* * URL parameters */ public static final String URL_PARAM_FORMAT = "format"; public static final String URL_PARAM_DELIMITER = "delimiter"; + /** + * HTTP header names + */ + public static final String HEADER_NAME_CURSOR = "Cursor"; + public static final String HEADER_NAME_TOOK_NANOS = "Took-nanos"; + public static final String HEADER_NAME_ASYNC_ID = "Async-ID"; + public static final String HEADER_NAME_ASYNC_PARTIAL = "Async-partial"; + public static final String HEADER_NAME_ASYNC_RUNNING = "Async-running"; + /** * SQL-related endpoints */ @@ -74,4 +97,8 @@ public final class Protocol { public static final String SQL_QUERY_REST_ENDPOINT = "/_sql"; public static final String SQL_TRANSLATE_REST_ENDPOINT = "/_sql/translate"; public static final String SQL_STATS_REST_ENDPOINT = "/_sql/stats"; + // async + public static final String SQL_ASYNC_REST_ENDPOINT = "/_sql/async/"; + public static final String SQL_ASYNC_STATUS_REST_ENDPOINT = SQL_ASYNC_REST_ENDPOINT + "status/"; + public static final String SQL_ASYNC_DELETE_REST_ENDPOINT = SQL_ASYNC_REST_ENDPOINT + "delete/"; } diff --git a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryRequest.java b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryRequest.java index 31fd0ec87e4b9..bea151703cb4b 100644 --- a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryRequest.java +++ b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryRequest.java @@ -28,6 +28,8 @@ import static org.elasticsearch.xpack.sql.proto.Protocol.FIELD_MULTI_VALUE_LENIENCY_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.FILTER_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.INDEX_INCLUDE_FROZEN_NAME; +import static 
org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ALIVE_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ON_COMPLETION_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.MODE_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.PAGE_TIMEOUT_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.PARAMS_NAME; @@ -36,6 +38,7 @@ import static org.elasticsearch.xpack.sql.proto.Protocol.RUNTIME_MAPPINGS_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.TIME_ZONE_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.VERSION_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.WAIT_FOR_COMPLETION_TIMEOUT_NAME; /** * Sql query request for JDBC/CLI client @@ -57,11 +60,16 @@ public class SqlQueryRequest extends AbstractSqlRequest { private final Boolean binaryCommunication; @Nullable private final Map runtimeMappings; + // Async settings + private final TimeValue waitForCompletionTimeout; + private final boolean keepOnCompletion; + private final TimeValue keepAlive; public SqlQueryRequest(String query, List params, ZoneId zoneId, int fetchSize, TimeValue requestTimeout, TimeValue pageTimeout, ToXContent filter, Boolean columnar, String cursor, RequestInfo requestInfo, boolean fieldMultiValueLeniency, boolean indexIncludeFrozen, - Boolean binaryCommunication, Map runtimeMappings) { + Boolean binaryCommunication, Map runtimeMappings, TimeValue waitForCompletionTimeout, + boolean keepOnCompletion, TimeValue keepAlive) { super(requestInfo); this.query = query; this.params = params; @@ -76,8 +84,19 @@ public SqlQueryRequest(String query, List params, ZoneId zon this.indexIncludeFrozen = indexIncludeFrozen; this.binaryCommunication = binaryCommunication; this.runtimeMappings = runtimeMappings; + this.waitForCompletionTimeout = waitForCompletionTimeout; + this.keepOnCompletion = keepOnCompletion; + this.keepAlive = keepAlive; } + public SqlQueryRequest(String query, List params, ZoneId zoneId, int fetchSize, + TimeValue requestTimeout, TimeValue pageTimeout, ToXContent filter, Boolean columnar, + String cursor, RequestInfo requestInfo, boolean fieldMultiValueLeniency, boolean indexIncludeFrozen, + Boolean binaryCommunication, Map runtimeMappings) { + this(query, params, zoneId, fetchSize, requestTimeout, pageTimeout, filter, columnar, cursor, requestInfo, fieldMultiValueLeniency, + indexIncludeFrozen, binaryCommunication, runtimeMappings, Protocol.DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT, + Protocol.DEFAULT_KEEP_ON_COMPLETION, Protocol.DEFAULT_KEEP_ALIVE); + } public SqlQueryRequest(String cursor, TimeValue requestTimeout, TimeValue pageTimeout, RequestInfo requestInfo, boolean binaryCommunication) { this("", emptyList(), Protocol.TIME_ZONE, Protocol.FETCH_SIZE, requestTimeout, pageTimeout, null, false, @@ -166,6 +185,18 @@ public Map runtimeMappings() { return runtimeMappings; } + public TimeValue waitForCompletionTimeout() { + return waitForCompletionTimeout; + } + + public boolean keepOnCompletion() { + return keepOnCompletion; + } + + public TimeValue keepAlive() { + return keepAlive; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -190,13 +221,17 @@ public boolean equals(Object o) { && fieldMultiValueLeniency == that.fieldMultiValueLeniency && indexIncludeFrozen == that.indexIncludeFrozen && Objects.equals(binaryCommunication, that.binaryCommunication) - && Objects.equals(runtimeMappings, that.runtimeMappings); + && Objects.equals(runtimeMappings, that.runtimeMappings) + && 
Objects.equals(waitForCompletionTimeout, that.waitForCompletionTimeout) + && keepOnCompletion == that.keepOnCompletion + && Objects.equals(keepAlive, that.keepAlive); } @Override public int hashCode() { return Objects.hash(super.hashCode(), query, zoneId, fetchSize, requestTimeout, pageTimeout, - filter, columnar, cursor, fieldMultiValueLeniency, indexIncludeFrozen, binaryCommunication, runtimeMappings); + filter, columnar, cursor, fieldMultiValueLeniency, indexIncludeFrozen, binaryCommunication, runtimeMappings, + waitForCompletionTimeout, keepOnCompletion, keepAlive); } @Override @@ -252,6 +287,15 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (runtimeMappings.isEmpty() == false) { builder.field(RUNTIME_MAPPINGS_NAME, runtimeMappings); } + if (waitForCompletionTimeout != null) { + builder.field(WAIT_FOR_COMPLETION_TIMEOUT_NAME, waitForCompletionTimeout.getStringRep()); + } + if (keepOnCompletion) { + builder.field(KEEP_ON_COMPLETION_NAME, keepOnCompletion); + } + if (keepAlive != null) { + builder.field(KEEP_ALIVE_NAME, keepAlive.getStringRep()); + } return builder; } } diff --git a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryResponse.java b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryResponse.java index 2db7db0352fff..20d1da1e976da 100644 --- a/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryResponse.java +++ b/x-pack/plugin/sql/sql-proto/src/main/java/org/elasticsearch/xpack/sql/proto/SqlQueryResponse.java @@ -21,6 +21,9 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.sql.proto.Protocol.COLUMNS_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.CURSOR_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.ID_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.IS_PARTIAL_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.IS_RUNNING_NAME; import static org.elasticsearch.xpack.sql.proto.Protocol.ROWS_NAME; /** @@ -33,16 +36,25 @@ public class SqlQueryResponse { objects -> new SqlQueryResponse( objects[0] == null ? 
"" : (String) objects[0], (List) objects[1], - (List>) objects[2])); + (List>) objects[2], + (String) objects[3], + objects[4] != null && (boolean) objects[4], + objects[5] != null && (boolean) objects[5])); public static final ParseField CURSOR = new ParseField(CURSOR_NAME); public static final ParseField COLUMNS = new ParseField(COLUMNS_NAME); public static final ParseField ROWS = new ParseField(ROWS_NAME); + public static final ParseField ID = new ParseField(ID_NAME); + public static final ParseField IS_PARTIAL = new ParseField(IS_PARTIAL_NAME); + public static final ParseField IS_RUNNING = new ParseField(IS_RUNNING_NAME); static { PARSER.declareString(optionalConstructorArg(), CURSOR); PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> ColumnInfo.fromXContent(p), COLUMNS); PARSER.declareField(constructorArg(), (p, c) -> parseRows(p), ROWS, ValueType.OBJECT_ARRAY); + PARSER.declareString(optionalConstructorArg(), ID); + PARSER.declareBoolean(optionalConstructorArg(), IS_PARTIAL); + PARSER.declareBoolean(optionalConstructorArg(), IS_RUNNING); } // TODO: Simplify cursor handling @@ -50,11 +62,23 @@ public class SqlQueryResponse { private final List columns; // TODO investigate reusing Page here - it probably is much more efficient private final List> rows; + // async + private final @Nullable String asyncExecutionId; + private final boolean isPartial; + private final boolean isRunning; public SqlQueryResponse(String cursor, @Nullable List columns, List> rows) { + this(cursor, columns, rows, null, false, false); + } + + public SqlQueryResponse(String cursor, @Nullable List columns, List> rows, String asyncExecutionId, + boolean isPartial, boolean isRunning) { this.cursor = cursor; this.columns = columns; this.rows = rows; + this.asyncExecutionId = asyncExecutionId; + this.isPartial = isPartial; + this.isRunning = isRunning; } /** @@ -77,6 +101,18 @@ public List> rows() { return rows; } + public String id() { + return asyncExecutionId; + } + + public boolean isPartial() { + return isPartial; + } + + public boolean isRunning() { + return isRunning; + } + public static SqlQueryResponse fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } @@ -114,12 +150,15 @@ public boolean equals(Object o) { SqlQueryResponse that = (SqlQueryResponse) o; return Objects.equals(cursor, that.cursor) && Objects.equals(columns, that.columns) && - Objects.equals(rows, that.rows); + Objects.equals(rows, that.rows) && + Objects.equals(asyncExecutionId, that.asyncExecutionId) && + isPartial == that.isPartial && + isRunning == that.isRunning; } @Override public int hashCode() { - return Objects.hash(cursor, columns, rows); + return Objects.hash(cursor, columns, rows, asyncExecutionId, isPartial, isRunning); } } diff --git a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/AbstractSqlBlockingIntegTestCase.java b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/AbstractSqlBlockingIntegTestCase.java new file mode 100644 index 0000000000000..64935d69b56a4 --- /dev/null +++ b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/AbstractSqlBlockingIntegTestCase.java @@ -0,0 +1,288 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.sql.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; +import org.elasticsearch.action.fieldcaps.FieldCapabilitiesAction; +import org.elasticsearch.action.support.ActionFilter; +import org.elasticsearch.action.support.ActionFilterChain; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.IndexModule; +import org.elasticsearch.index.shard.SearchOperationListener; +import org.elasticsearch.license.LicenseService; +import org.elasticsearch.plugins.ActionPlugin; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.PluginsService; +import org.elasticsearch.search.internal.ReaderContext; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xpack.core.XPackSettings; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +import static org.elasticsearch.test.ESIntegTestCase.Scope.SUITE; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; + +/** + * IT tests that can block SQL execution at different places + */ +@ESIntegTestCase.ClusterScope(scope = SUITE, numDataNodes = 0, numClientNodes = 0, maxNumDataNodes = 0) +public abstract class AbstractSqlBlockingIntegTestCase extends ESIntegTestCase { + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings)); + settings.put(XPackSettings.SECURITY_ENABLED.getKey(), false); + settings.put(XPackSettings.WATCHER_ENABLED.getKey(), false); + settings.put(XPackSettings.GRAPH_ENABLED.getKey(), false); + settings.put(XPackSettings.MACHINE_LEARNING_ENABLED.getKey(), false); + settings.put(LicenseService.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial"); + return settings.build(); + } + + @Override + protected Collection> nodePlugins() { + return Arrays.asList(LocalStateSQLXPackPlugin.class, SearchBlockPlugin.class); + } + + protected List initBlockFactory(boolean searchBlock, boolean fieldCapsBlock) { + List plugins = new ArrayList<>(); + for (PluginsService pluginsService : internalCluster().getInstances(PluginsService.class)) { + plugins.addAll(pluginsService.filterPlugins(SearchBlockPlugin.class)); + } + for (SearchBlockPlugin plugin : plugins) { + plugin.reset(); + if (searchBlock) { + plugin.enableSearchBlock(); + } + if (fieldCapsBlock) { + plugin.enableFieldCapBlock(); + } + } + return plugins; + } + + protected void disableBlocks(List plugins) { + disableFieldCapBlocks(plugins); + disableSearchBlocks(plugins); + } + + protected void disableSearchBlocks(List plugins) { + for (SearchBlockPlugin plugin : plugins) { + 
plugin.disableSearchBlock(); + } + } + + protected void disableFieldCapBlocks(List plugins) { + for (SearchBlockPlugin plugin : plugins) { + plugin.disableFieldCapBlock(); + } + } + + protected void awaitForBlockedSearches(List plugins, String index) throws Exception { + int numberOfShards = getNumShards(index).numPrimaries; + assertBusy(() -> { + int numberOfBlockedPlugins = getNumberOfContexts(plugins); + logger.trace("The plugin blocked on {} out of {} shards", numberOfBlockedPlugins, numberOfShards); + assertThat(numberOfBlockedPlugins, greaterThan(0)); + }); + } + + protected int getNumberOfContexts(List plugins) throws Exception { + int count = 0; + for (SearchBlockPlugin plugin : plugins) { + count += plugin.contexts.get(); + } + return count; + } + + protected int getNumberOfFieldCaps(List plugins) throws Exception { + int count = 0; + for (SearchBlockPlugin plugin : plugins) { + count += plugin.fieldCaps.get(); + } + return count; + } + + protected void awaitForBlockedFieldCaps(List plugins) throws Exception { + assertBusy(() -> { + int numberOfBlockedPlugins = getNumberOfFieldCaps(plugins); + logger.trace("The plugin blocked on {} nodes", numberOfBlockedPlugins); + assertThat(numberOfBlockedPlugins, greaterThan(0)); + }); + } + + public static class SearchBlockPlugin extends Plugin implements ActionPlugin { + protected final Logger logger = LogManager.getLogger(getClass()); + + private final AtomicInteger contexts = new AtomicInteger(); + + private final AtomicInteger fieldCaps = new AtomicInteger(); + + private final AtomicBoolean shouldBlockOnSearch = new AtomicBoolean(false); + + private final AtomicBoolean shouldBlockOnFieldCapabilities = new AtomicBoolean(false); + + private final String nodeId; + + private final ExecutorService executorService = Executors.newFixedThreadPool(1); + + public void reset() { + contexts.set(0); + fieldCaps.set(0); + } + + public void disableSearchBlock() { + shouldBlockOnSearch.set(false); + } + + public void enableSearchBlock() { + shouldBlockOnSearch.set(true); + } + + + public void disableFieldCapBlock() { + shouldBlockOnFieldCapabilities.set(false); + } + + public void enableFieldCapBlock() { + shouldBlockOnFieldCapabilities.set(true); + } + + public SearchBlockPlugin(Settings settings, Path configPath) throws Exception { + nodeId = settings.get("node.name"); + } + + @Override + public void onIndexModule(IndexModule indexModule) { + super.onIndexModule(indexModule); + indexModule.addSearchOperationListener(new SearchOperationListener() { + @Override + public void onNewReaderContext(ReaderContext readerContext) { + contexts.incrementAndGet(); + try { + logger.trace("blocking search on " + nodeId); + assertBusy(() -> assertFalse(shouldBlockOnSearch.get())); + logger.trace("unblocking search on " + nodeId); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }); + } + + @Override + public List getActionFilters() { + List list = new ArrayList<>(); + list.add(new ActionFilter() { + @Override + public int order() { + return 0; + } + + @Override + public void apply( + Task task, + String action, + Request request, + ActionListener listener, + ActionFilterChain chain) { + + if (action.equals(FieldCapabilitiesAction.NAME)) { + final Consumer actionWrapper = resp -> { + try { + fieldCaps.incrementAndGet(); + logger.trace("blocking field caps on " + nodeId); + assertBusy(() -> assertFalse(shouldBlockOnFieldCapabilities.get())); + logger.trace("unblocking field caps on " + nodeId); + } catch (Exception e) { + throw new 
RuntimeException(e); + } finally { + listener.onResponse(resp); + } + logger.trace("unblocking field caps on " + nodeId); + }; + final Thread originalThread = Thread.currentThread(); + chain.proceed(task, action, request, + ActionListener.wrap( + resp -> { + if (originalThread == Thread.currentThread()) { + // async if we never exited the original thread + executorService.execute(() -> actionWrapper.accept(resp)); + } else { + actionWrapper.accept(resp); + } + }, + listener::onFailure) + ); + } else { + chain.proceed(task, action, request, listener); + } + } + }); + return list; + } + + @Override + public void close() throws IOException { + List runnables = executorService.shutdownNow(); + assertTrue(runnables.isEmpty()); + } + } + + protected TaskId findTaskWithXOpaqueId(String id, String action) { + TaskInfo taskInfo = getTaskInfoWithXOpaqueId(id, action); + if (taskInfo != null) { + return taskInfo.getTaskId(); + } else { + return null; + } + } + + protected TaskInfo getTaskInfoWithXOpaqueId(String id, String action) { + ListTasksResponse tasks = client().admin().cluster().prepareListTasks().setActions(action).get(); + for (TaskInfo task : tasks.getTasks()) { + if (id.equals(task.getHeaders().get(Task.X_OPAQUE_ID))) { + return task; + } + } + return null; + } + + protected TaskId cancelTaskWithXOpaqueId(String id, String action) { + TaskId taskId = findTaskWithXOpaqueId(id, action); + assertNotNull(taskId); + logger.trace("Cancelling task " + taskId); + CancelTasksResponse response = client().admin().cluster().prepareCancelTasks().setTaskId(taskId).get(); + assertThat(response.getTasks(), hasSize(1)); + assertThat(response.getTasks().get(0).getAction(), equalTo(action)); + logger.trace("Task is cancelled " + taskId); + return taskId; + } + +} diff --git a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/AsyncSqlSearchActionIT.java b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/AsyncSqlSearchActionIT.java new file mode 100644 index 0000000000000..7b4c851d1e628 --- /dev/null +++ b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/AsyncSqlSearchActionIT.java @@ -0,0 +1,324 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.sql.action; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.NoShardAvailableActionException; +import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.ByteBufferStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.script.MockScriptPlugin; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.core.XPackPlugin; +import org.elasticsearch.xpack.core.async.AsyncExecutionId; +import org.elasticsearch.xpack.core.async.DeleteAsyncResultAction; +import org.elasticsearch.xpack.core.async.DeleteAsyncResultRequest; +import org.elasticsearch.xpack.core.async.GetAsyncResultRequest; +import org.elasticsearch.xpack.core.async.StoredAsyncResponse; +import org.elasticsearch.xpack.sql.plugin.SqlAsyncGetResultsAction; +import org.elasticsearch.xpack.sql.proto.Protocol; +import org.junit.After; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.Function; + +import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertFutureThrows; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class AsyncSqlSearchActionIT extends AbstractSqlBlockingIntegTestCase { + + private final ExecutorService executorService = Executors.newFixedThreadPool(1); + + NamedWriteableRegistry registry = new NamedWriteableRegistry(new SearchModule(Settings.EMPTY, + Collections.emptyList()).getNamedWriteables()); + + /** + * Shutdown the executor so we don't leak threads into other test runs. 
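+ * Note that shutdown() is an orderly shutdown: already queued work is allowed to finish and no tasks are interrupted.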
+ */ + @After + public void shutdownExec() { + executorService.shutdown(); + } + + private void prepareIndex() throws Exception { + assertAcked(client().admin().indices().prepareCreate("test") + .setMapping("val", "type=integer", "event_type", "type=keyword", "@timestamp", "type=date", "i", "type=integer") + .get()); + createIndex("idx_unmapped"); + + int numDocs = randomIntBetween(6, 20); + + List builders = new ArrayList<>(); + + for (int i = 0; i < numDocs; i++) { + int fieldValue = randomIntBetween(0, 10); + builders.add(client().prepareIndex("test").setSource( + jsonBuilder().startObject() + .field("val", fieldValue) + .field("event_type", "my_event") + .field("@timestamp", "2020-04-09T12:35:48Z") + .field("i", i) + .endObject())); + } + indexRandom(true, builders); + } + + public void testBasicAsyncExecution() throws Exception { + prepareIndex(); + + boolean success = randomBoolean(); + String query = "SELECT event_type FROM test WHERE " + (success ? "i=1" : "10/i=1"); + SqlQueryRequestBuilder builder = new SqlQueryRequestBuilder(client(), SqlQueryAction.INSTANCE) + .query(query).waitForCompletionTimeout(TimeValue.timeValueMillis(1)); + + List plugins = initBlockFactory(true, false); + + logger.trace("Starting async search"); + SqlQueryResponse response = client().execute(SqlQueryAction.INSTANCE, builder.request()).get(); + assertThat(response.isRunning(), is(true)); + assertThat(response.isPartial(), is(true)); + assertThat(response.id(), notNullValue()); + + logger.trace("Waiting for block to be established"); + awaitForBlockedSearches(plugins, "test"); + logger.trace("Block is established"); + + if (randomBoolean()) { + // let's timeout first + GetAsyncResultRequest getResultsRequest = new GetAsyncResultRequest(response.id()) + .setKeepAlive(TimeValue.timeValueMinutes(10)) + .setWaitForCompletionTimeout(TimeValue.timeValueMillis(10)); + SqlQueryResponse responseWithTimeout = client().execute(SqlAsyncGetResultsAction.INSTANCE, getResultsRequest).get(); + assertThat(responseWithTimeout.isRunning(), is(true)); + assertThat(responseWithTimeout.isPartial(), is(true)); + assertThat(responseWithTimeout.id(), equalTo(response.id())); + } + + // Now we wait + GetAsyncResultRequest getResultsRequest = new GetAsyncResultRequest(response.id()) + .setKeepAlive(TimeValue.timeValueMinutes(10)) + .setWaitForCompletionTimeout(TimeValue.timeValueSeconds(10)); + ActionFuture future = client().execute(SqlAsyncGetResultsAction.INSTANCE, getResultsRequest); + disableBlocks(plugins); + if (success) { + response = future.get(); + assertThat(response, notNullValue()); + assertThat(response.rows().size(), equalTo(1)); + } else { + Exception ex = expectThrows(Exception.class, future::actionGet); + assertThat(ex.getCause().getMessage(), containsString("by zero")); + } + AcknowledgedResponse deleteResponse = + client().execute(DeleteAsyncResultAction.INSTANCE, new DeleteAsyncResultRequest(response.id())).actionGet(); + assertThat(deleteResponse.isAcknowledged(), equalTo(true)); + } + + public void testGoingAsync() throws Exception { + prepareIndex(); + + boolean success = randomBoolean(); + String query = "SELECT event_type FROM test WHERE " + (success ? 
"i=1" : "10/i=1"); + SqlQueryRequestBuilder builder = new SqlQueryRequestBuilder(client(), SqlQueryAction.INSTANCE) + .query(query).waitForCompletionTimeout(TimeValue.timeValueMillis(1)); + + boolean customKeepAlive = randomBoolean(); + TimeValue keepAliveValue; + if (customKeepAlive) { + keepAliveValue = TimeValue.parseTimeValue(randomTimeValue(1, 5, "d"), "test"); + builder.keepAlive(keepAliveValue); + } else { + keepAliveValue = Protocol.DEFAULT_KEEP_ALIVE; + } + + List plugins = initBlockFactory(true, false); + + String opaqueId = randomAlphaOfLength(10); + logger.trace("Starting async search"); + SqlQueryResponse response = client().filterWithHeader(Collections.singletonMap(Task.X_OPAQUE_ID, opaqueId)) + .execute(SqlQueryAction.INSTANCE, builder.request()).get(); + assertThat(response.isRunning(), is(true)); + assertThat(response.isPartial(), is(true)); + assertThat(response.id(), notNullValue()); + + logger.trace("Waiting for block to be established"); + awaitForBlockedSearches(plugins, "test"); + logger.trace("Block is established"); + + String id = response.id(); + TaskId taskId = findTaskWithXOpaqueId(opaqueId, SqlQueryAction.NAME + "[a]"); + assertThat(taskId, notNullValue()); + + disableBlocks(plugins); + + assertBusy(() -> assertThat(findTaskWithXOpaqueId(opaqueId, SqlQueryAction.NAME + "[a]"), nullValue())); + StoredAsyncResponse doc = getStoredRecord(id); + // Make sure that the expiration time is not more than 1 min different from the current time + keep alive + assertThat(System.currentTimeMillis() + keepAliveValue.getMillis() - doc.getExpirationTime(), + lessThan(doc.getExpirationTime() + TimeValue.timeValueMinutes(1).getMillis())); + if (success) { + assertThat(doc.getException(), nullValue()); + assertThat(doc.getResponse(), notNullValue()); + assertThat(doc.getResponse().rows().size(), equalTo(1)); + } else { + assertThat(doc.getException(), notNullValue()); + assertThat(doc.getResponse(), nullValue()); + assertThat(doc.getException().getCause().getMessage(), containsString("by zero")); + } + } + + public void testAsyncCancellation() throws Exception { + prepareIndex(); + + boolean success = randomBoolean(); + String query = "SELECT event_type FROM test WHERE " + (success ? 
"i=1" : "10/i=1"); + SqlQueryRequestBuilder builder = new SqlQueryRequestBuilder(client(), SqlQueryAction.INSTANCE) + .query(query).waitForCompletionTimeout(TimeValue.timeValueMillis(1)); + + boolean customKeepAlive = randomBoolean(); + final TimeValue keepAliveValue; + if (customKeepAlive) { + keepAliveValue = TimeValue.parseTimeValue(randomTimeValue(1, 5, "d"), "test"); + builder.keepAlive(keepAliveValue); + } + + List plugins = initBlockFactory(true, false); + + String opaqueId = randomAlphaOfLength(10); + logger.trace("Starting async search"); + SqlQueryResponse response = client().filterWithHeader(Collections.singletonMap(Task.X_OPAQUE_ID, opaqueId)) + .execute(SqlQueryAction.INSTANCE, builder.request()).get(); + assertThat(response.isRunning(), is(true)); + assertThat(response.isPartial(), is(true)); + assertThat(response.id(), notNullValue()); + + logger.trace("Waiting for block to be established"); + awaitForBlockedSearches(plugins, "test"); + logger.trace("Block is established"); + + ActionFuture deleteResponse = + client().execute(DeleteAsyncResultAction.INSTANCE, new DeleteAsyncResultRequest(response.id())); + disableBlocks(plugins); + assertThat(deleteResponse.actionGet().isAcknowledged(), equalTo(true)); + + deleteResponse = client().execute(DeleteAsyncResultAction.INSTANCE, new DeleteAsyncResultRequest(response.id())); + assertFutureThrows(deleteResponse, ResourceNotFoundException.class); + } + + public void testFinishingBeforeTimeout() throws Exception { + prepareIndex(); + + boolean success = randomBoolean(); + boolean keepOnCompletion = randomBoolean(); + String query = "SELECT event_type FROM test WHERE " + (success ? "i=1" : "10/i=1"); + SqlQueryRequestBuilder builder = new SqlQueryRequestBuilder(client(), SqlQueryAction.INSTANCE) + .query(query).waitForCompletionTimeout(TimeValue.timeValueSeconds(10)); + if (keepOnCompletion || randomBoolean()) { + builder.keepOnCompletion(keepOnCompletion); + } + SqlQueryRequest request = builder.request(); + + if (success) { + SqlQueryResponse response = client().execute(SqlQueryAction.INSTANCE, request).get(); + assertThat(response.isRunning(), is(false)); + assertThat(response.isPartial(), is(false)); + assertThat(response.id(), notNullValue()); + assertThat(response.rows().size(), equalTo(1)); + if (keepOnCompletion) { + StoredAsyncResponse doc = getStoredRecord(response.id()); + assertThat(doc, notNullValue()); + assertThat(doc.getException(), nullValue()); + assertThat(doc.getResponse(), notNullValue()); + assertThat(doc.getResponse().rows().size(), equalTo(1)); + SqlQueryResponse storedResponse = client().execute(SqlAsyncGetResultsAction.INSTANCE, + new GetAsyncResultRequest(response.id())).actionGet(); + assertThat(storedResponse, equalTo(response)); + + AcknowledgedResponse deleteResponse = + client().execute(DeleteAsyncResultAction.INSTANCE, new DeleteAsyncResultRequest(response.id())).actionGet(); + assertThat(deleteResponse.isAcknowledged(), equalTo(true)); + } + } else { + Exception ex = expectThrows(Exception.class, + () -> client().execute(SqlQueryAction.INSTANCE, request).get()); + assertThat(ex.getMessage(), containsString("by zero")); + } + } + + + public StoredAsyncResponse getStoredRecord(String id) throws Exception { + try { + GetResponse doc = client().prepareGet(XPackPlugin.ASYNC_RESULTS_INDEX, AsyncExecutionId.decode(id).getDocId()).get(); + if (doc.isExists()) { + String value = doc.getSource().get("result").toString(); + try (ByteBufferStreamInput buf = new 
ByteBufferStreamInput(ByteBuffer.wrap(Base64.getDecoder().decode(value)))) { + try (StreamInput in = new NamedWriteableAwareStreamInput(buf, registry)) { + in.setVersion(Version.readVersion(in)); + return new StoredAsyncResponse<>(SqlQueryResponse::new, in); + } + } + } + return null; + } catch (IndexNotFoundException | NoShardAvailableActionException ex) { + return null; + } + } + + public static class FakePainlessScriptPlugin extends MockScriptPlugin { + + @Override + protected Map, Object>> pluginScripts() { + Map, Object>> scripts = new HashMap<>(); + scripts.put("InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(InternalSqlScriptUtils.div(" + + "params.v0,InternalQlScriptUtils.docValue(doc,params.v1)),params.v2))", FakePainlessScriptPlugin::fail); + return scripts; + } + + public static Object fail(Map arg) { + throw new ArithmeticException("Division by zero"); + } + + public String pluginScriptLang() { + // Faking painless + return "painless"; + } + } + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopy(super.nodePlugins(), FakePainlessScriptPlugin.class); + } +} diff --git a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/RestSqlCancellationIT.java b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/RestSqlCancellationIT.java new file mode 100644 index 0000000000000..cda6fdb3ec3fb --- /dev/null +++ b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/RestSqlCancellationIT.java @@ -0,0 +1,171 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.sql.action; + +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.client.Cancellable; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.network.NetworkModule; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.test.junit.annotations.TestLogging; +import org.elasticsearch.transport.Netty4Plugin; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.nio.NioTransportPlugin; +import org.elasticsearch.xpack.sql.proto.Protocol; +import org.junit.BeforeClass; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class RestSqlCancellationIT extends AbstractSqlBlockingIntegTestCase { + + private static String nodeHttpTypeKey; + + @BeforeClass + public static void setUpTransport() { + nodeHttpTypeKey = getHttpTypeKey(randomFrom(Netty4Plugin.class, NioTransportPlugin.class)); + } + + @Override + protected boolean addMockHttpTransport() { + return false; // enable http + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put(NetworkModule.HTTP_TYPE_KEY, nodeHttpTypeKey).build(); + } + + private static String getHttpTypeKey(Class clazz) { + if (clazz.equals(NioTransportPlugin.class)) { + return NioTransportPlugin.NIO_HTTP_TRANSPORT_NAME; + } else { + assert clazz.equals(Netty4Plugin.class); + return Netty4Plugin.NETTY_HTTP_TRANSPORT_NAME; + } + } + + @Override + protected Collection> nodePlugins() { + List> plugins = new ArrayList<>(super.nodePlugins()); + plugins.add(getTestTransportPlugin()); + plugins.add(Netty4Plugin.class); + plugins.add(NioTransportPlugin.class); + return plugins; + } + + @TestLogging(value = "org.elasticsearch.xpack.sql:TRACE", reason = "debug") + public void testRestCancellation() throws Exception { + assertAcked(client().admin().indices().prepareCreate("test") + .setMapping("val", "type=integer", "event_type", "type=keyword", "@timestamp", "type=date") + .get()); + createIndex("idx_unmapped"); + + int numDocs = randomIntBetween(6, 20); + + List builders = new ArrayList<>(); + + for (int i = 0; i < numDocs; i++) { + int fieldValue = randomIntBetween(0, 10); + builders.add(client().prepareIndex("test").setSource( + jsonBuilder().startObject() + .field("val", fieldValue).field("event_type", "my_event").field("@timestamp", "2020-04-09T12:35:48Z") + .endObject())); + } + + indexRandom(true, builders); + + // We are cancelling during both mapping and searching but we cancel during mapping so we should never reach the second block + List plugins = 
initBlockFactory(true, true); + SqlQueryRequest sqlRequest = new SqlQueryRequestBuilder(client(), SqlQueryAction.INSTANCE) + .query("SELECT event_type FROM test WHERE val=1").request(); + String id = randomAlphaOfLength(10); + + Request request = new Request("POST", Protocol.SQL_QUERY_REST_ENDPOINT); + request.setJsonEntity(Strings.toString(sqlRequest)); + request.setOptions(RequestOptions.DEFAULT.toBuilder().addHeader(Task.X_OPAQUE_ID, id)); + logger.trace("Preparing search"); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference error = new AtomicReference<>(); + Cancellable cancellable = getRestClient().performRequestAsync(request, new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + error.set(exception); + latch.countDown(); + } + }); + + logger.trace("Waiting for block to be established"); + awaitForBlockedFieldCaps(plugins); + logger.trace("Block is established"); + TaskInfo blockedTaskInfo = getTaskInfoWithXOpaqueId(id, SqlQueryAction.NAME); + assertThat(blockedTaskInfo, notNullValue()); + cancellable.cancel(); + logger.trace("Request is cancelled"); + + assertBusy(() -> { + for (TransportService transportService : internalCluster().getInstances(TransportService.class)) { + if (transportService.getLocalNode().getId().equals(blockedTaskInfo.getTaskId().getNodeId())) { + Task task = transportService.getTaskManager().getTask(blockedTaskInfo.getId()); + if (task != null) { + assertThat(task, instanceOf(SqlQueryTask.class)); + SqlQueryTask sqlSearchTask = (SqlQueryTask) task; + logger.trace("Waiting for cancellation to be propagated: {} ", sqlSearchTask.isCancelled()); + assertThat(sqlSearchTask.isCancelled(), equalTo(true)); + } + return; + } + } + fail("Task not found"); + }); + + logger.trace("Disabling field cap blocks"); + disableFieldCapBlocks(plugins); + // The task should be cancelled before ever reaching search blocks + assertBusy(() -> { + assertThat(getTaskInfoWithXOpaqueId(id, SqlQueryAction.NAME), nullValue()); + }); + // Make sure it didn't reach search blocks + assertThat(getNumberOfContexts(plugins), equalTo(0)); + disableSearchBlocks(plugins); + + latch.await(); + assertThat(error.get(), instanceOf(CancellationException.class)); + } + + @Override + protected boolean ignoreExternalCluster() { + return true; + } +} diff --git a/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/SqlCancellationIT.java b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/SqlCancellationIT.java new file mode 100644 index 0000000000000..ae91d10d74f0e --- /dev/null +++ b/x-pack/plugin/sql/src/internalClusterTest/java/org/elasticsearch/xpack/sql/action/SqlCancellationIT.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.sql.action; + +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskCancelledException; +import org.junit.After; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class SqlCancellationIT extends AbstractSqlBlockingIntegTestCase { + + private final ExecutorService executorService = Executors.newFixedThreadPool(1); + + /** + * Shutdown the executor so we don't leak threads into other test runs. + */ + @After + public void shutdownExec() { + executorService.shutdown(); + } + + public void testCancellation() throws Exception { + assertAcked(client().admin().indices().prepareCreate("test") + .setMapping("val", "type=integer", "event_type", "type=keyword", "@timestamp", "type=date") + .get()); + createIndex("idx_unmapped"); + + int numDocs = randomIntBetween(6, 20); + + List builders = new ArrayList<>(); + + for (int i = 0; i < numDocs; i++) { + int fieldValue = randomIntBetween(0, 10); + builders.add(client().prepareIndex("test").setSource( + jsonBuilder().startObject() + .field("val", fieldValue).field("event_type", "my_event").field("@timestamp", "2020-04-09T12:35:48Z") + .endObject())); + } + + indexRandom(true, builders); + boolean cancelDuringSearch = randomBoolean(); + List plugins = initBlockFactory(cancelDuringSearch, cancelDuringSearch == false); + SqlQueryRequest request = new SqlQueryRequestBuilder(client(), SqlQueryAction.INSTANCE) + .query("SELECT event_type FROM test WHERE val=1").request(); + String id = randomAlphaOfLength(10); + logger.trace("Preparing search"); + // We might perform field caps on the same thread if it is local client, so we cannot use the standard mechanism + Future future = executorService.submit(() -> + client().filterWithHeader(Collections.singletonMap(Task.X_OPAQUE_ID, id)).execute(SqlQueryAction.INSTANCE, request).get() + ); + logger.trace("Waiting for block to be established"); + if (cancelDuringSearch) { + awaitForBlockedSearches(plugins, "test"); + } else { + awaitForBlockedFieldCaps(plugins); + } + logger.trace("Block is established"); + cancelTaskWithXOpaqueId(id, SqlQueryAction.NAME); + + disableBlocks(plugins); + Exception exception = expectThrows(Exception.class, future::get); + Throwable inner = ExceptionsHelper.unwrap(exception, SearchPhaseExecutionException.class); + if (cancelDuringSearch) { + // Make sure we cancelled inside search + assertNotNull(inner); + assertThat(inner, instanceOf(SearchPhaseExecutionException.class)); + assertThat(inner.getCause(), instanceOf(TaskCancelledException.class)); + } else { + // Make sure we were not cancelled inside search + assertNull(inner); + assertThat(getNumberOfContexts(plugins), equalTo(0)); + Throwable cancellationException = ExceptionsHelper.unwrap(exception, TaskCancelledException.class); + assertNotNull(cancellationException); + } + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java 
b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java index b5a7ccb1194f4..b5ff6f4e4cc4f 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/execution/PlanExecutor.java @@ -118,6 +118,14 @@ public void cleanCursor(SqlConfiguration cfg, Cursor cursor, ActionListener output, QueryContainer query, String index, Ac l = new ScrollActionListener(listener, client, cfg, output, query); } + if (cfg.task() != null && cfg.task().isCancelled()) { + listener.onFailure(new TaskCancelledException("cancelled")); + return; + } client.search(search, l); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncDeleteResultsAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncDeleteResultsAction.java new file mode 100644 index 0000000000000..0810a99e569c4 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncDeleteResultsAction.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.async.DeleteAsyncResultAction; +import org.elasticsearch.xpack.core.async.DeleteAsyncResultRequest; + +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.DELETE; +import static org.elasticsearch.xpack.sql.proto.Protocol.ID_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_ASYNC_DELETE_REST_ENDPOINT; + +public class RestSqlAsyncDeleteResultsAction extends BaseRestHandler { + @Override + public List routes() { + return List.of(new Route(DELETE, SQL_ASYNC_DELETE_REST_ENDPOINT + "{" + ID_NAME + "}")); + } + + @Override + public String getName() { + return "sql_delete_async_result"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + DeleteAsyncResultRequest delete = new DeleteAsyncResultRequest(request.param(ID_NAME)); + return channel -> client.execute(DeleteAsyncResultAction.INSTANCE, delete, new RestToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncGetResultsAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncGetResultsAction.java new file mode 100644 index 0000000000000..3762f88d245d0 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncGetResultsAction.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.xpack.core.async.GetAsyncResultRequest; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.sql.proto.Protocol.ID_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.KEEP_ALIVE_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_ASYNC_REST_ENDPOINT; +import static org.elasticsearch.xpack.sql.proto.Protocol.URL_PARAM_DELIMITER; +import static org.elasticsearch.xpack.sql.proto.Protocol.WAIT_FOR_COMPLETION_TIMEOUT_NAME; + +public class RestSqlAsyncGetResultsAction extends BaseRestHandler { + @Override + public List routes() { + return List.of(new Route(GET, SQL_ASYNC_REST_ENDPOINT + "{" + ID_NAME + "}")); + } + + @Override + public String getName() { + return "sql_get_async_result"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + GetAsyncResultRequest get = new GetAsyncResultRequest(request.param(ID_NAME)); + if (request.hasParam(WAIT_FOR_COMPLETION_TIMEOUT_NAME)) { + get.setWaitForCompletionTimeout(request.paramAsTime(WAIT_FOR_COMPLETION_TIMEOUT_NAME, get.getWaitForCompletionTimeout())); + } + if (request.hasParam(KEEP_ALIVE_NAME)) { + get.setKeepAlive(request.paramAsTime(KEEP_ALIVE_NAME, get.getKeepAlive())); + } + return channel -> client.execute(SqlAsyncGetResultsAction.INSTANCE, get, new SqlResponseListener(channel, request)); + } + + @Override + protected Set responseParams() { + return Collections.singleton(URL_PARAM_DELIMITER); + } + +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncGetStatusAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncGetStatusAction.java new file mode 100644 index 0000000000000..893ba834769c4 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlAsyncGetStatusAction.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestStatusToXContentListener; +import org.elasticsearch.xpack.core.async.GetAsyncStatusRequest; + +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.GET; +import static org.elasticsearch.xpack.sql.proto.Protocol.ID_NAME; +import static org.elasticsearch.xpack.sql.proto.Protocol.SQL_ASYNC_STATUS_REST_ENDPOINT; + +public class RestSqlAsyncGetStatusAction extends BaseRestHandler { + @Override + public List routes() { + return List.of(new Route(GET, SQL_ASYNC_STATUS_REST_ENDPOINT + "{" + ID_NAME + "}")); + } + + @Override + public String getName() { + return "sql_get_async_status"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { + GetAsyncStatusRequest statusRequest = new GetAsyncStatusRequest(request.param(ID_NAME)); + return channel -> client.execute(SqlAsyncGetStatusAction.INSTANCE, statusRequest, new RestStatusToXContentListener<>(channel)); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlQueryAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlQueryAction.java index ac934571fe2e3..53b0b84c014a0 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlQueryAction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/RestSqlQueryAction.java @@ -10,36 +10,25 @@ import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.xcontent.MediaType; import org.elasticsearch.common.xcontent.MediaTypeRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.rest.BaseRestHandler; -import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestRequest; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.rest.action.RestResponseListener; +import org.elasticsearch.rest.action.RestCancellableNodeClient; import org.elasticsearch.xpack.sql.action.SqlQueryAction; import org.elasticsearch.xpack.sql.action.SqlQueryRequest; -import org.elasticsearch.xpack.sql.action.SqlQueryResponse; import org.elasticsearch.xpack.sql.proto.Protocol; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.List; -import java.util.Locale; import java.util.Set; -import static java.util.Collections.emptySet; import static org.elasticsearch.rest.RestRequest.Method.GET; import static org.elasticsearch.rest.RestRequest.Method.POST; import static org.elasticsearch.xpack.sql.proto.Protocol.URL_PARAM_DELIMITER; public class RestSqlQueryAction extends BaseRestHandler { - private final SqlMediaTypeParser sqlMediaTypeParser = new SqlMediaTypeParser(); - @Override public List routes() { return List.of( @@ -59,51 +48,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli sqlRequest = SqlQueryRequest.fromXContent(parser); } - MediaType responseMediaType = sqlMediaTypeParser.getResponseMediaType(request, sqlRequest); - if (responseMediaType == null) { - String msg = String.format(Locale.ROOT, "Invalid response content type: Accept=[%s], Content-Type=[%s], 
format=[%s]", - request.header("Accept"), request.header("Content-Type"), request.param("format")); - throw new IllegalArgumentException(msg); - } - - /* - * Special handling for the "delimiter" parameter which should only be - * checked for being present or not in the case of CSV format. We cannot - * override {@link BaseRestHandler#responseParams()} because this - * parameter should only be checked for CSV, not always. - */ - if ((responseMediaType instanceof XContentType || ((TextFormat) responseMediaType) != TextFormat.CSV) - && request.hasParam(URL_PARAM_DELIMITER)) { - throw new IllegalArgumentException(unrecognized(request, Collections.singleton(URL_PARAM_DELIMITER), emptySet(), "parameter")); - } - - long startNanos = System.nanoTime(); - return channel -> client.execute(SqlQueryAction.INSTANCE, sqlRequest, new RestResponseListener(channel) { - @Override - public RestResponse buildResponse(SqlQueryResponse response) throws Exception { - RestResponse restResponse; - - // XContent branch - if (responseMediaType instanceof XContentType) { - XContentType type = (XContentType) responseMediaType; - XContentBuilder builder = channel.newBuilder(request.getXContentType(), type, true); - response.toXContent(builder, request); - restResponse = new BytesRestResponse(RestStatus.OK, builder); - } else { // TextFormat - TextFormat type = (TextFormat) responseMediaType; - final String data = type.format(request, response); - - restResponse = new BytesRestResponse(RestStatus.OK, type.contentType(request), data.getBytes(StandardCharsets.UTF_8)); - - if (response.hasCursor()) { - restResponse.addHeader("Cursor", response.cursor()); - } - } - - restResponse.addHeader("Took-nanos", Long.toString(System.nanoTime() - startNanos)); - return restResponse; - } - }); + return channel -> { + RestCancellableNodeClient cancellableClient = new RestCancellableNodeClient(client, request.getHttpChannel()); + cancellableClient.execute(SqlQueryAction.INSTANCE, sqlRequest, new SqlResponseListener(channel, request, sqlRequest)); + }; } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlAsyncGetResultsAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlAsyncGetResultsAction.java new file mode 100644 index 0000000000000..cf3c422901877 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlAsyncGetResultsAction.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.xpack.sql.action.SqlQueryResponse; + +import static org.elasticsearch.xpack.core.sql.SqlAsyncActionNames.SQL_ASYNC_GET_RESULT_ACTION_NAME; + +public class SqlAsyncGetResultsAction extends ActionType { + public static final SqlAsyncGetResultsAction INSTANCE = new SqlAsyncGetResultsAction(); + public static final String NAME = SQL_ASYNC_GET_RESULT_ACTION_NAME; + + private SqlAsyncGetResultsAction() { + super(NAME, SqlQueryResponse::new); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlAsyncGetStatusAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlAsyncGetStatusAction.java new file mode 100644 index 0000000000000..366d8c606f86b --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlAsyncGetStatusAction.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.action.ActionType; +import org.elasticsearch.xpack.ql.async.QlStatusResponse; + +import static org.elasticsearch.xpack.core.sql.SqlAsyncActionNames.SQL_ASYNC_GET_STATUS_ACTION_NAME; + +public class SqlAsyncGetStatusAction extends ActionType { + public static final SqlAsyncGetStatusAction INSTANCE = new SqlAsyncGetStatusAction(); + public static final String NAME = SQL_ASYNC_GET_STATUS_ACTION_NAME; + + private SqlAsyncGetStatusAction() { + super(NAME, QlStatusResponse::new); + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParser.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParser.java index 7b5eaea4b70ff..dc764553928fb 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParser.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParser.java @@ -9,11 +9,13 @@ import org.elasticsearch.common.xcontent.MediaType; import org.elasticsearch.common.xcontent.MediaTypeRegistry; +import org.elasticsearch.common.xcontent.ParsedMediaType; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.xpack.sql.action.SqlQueryRequest; import org.elasticsearch.xpack.sql.proto.Mode; +import java.util.Locale; import static org.elasticsearch.xpack.sql.proto.Protocol.URL_PARAM_FORMAT; @@ -34,27 +36,49 @@ public class SqlMediaTypeParser { * isn't but there is a {@code Accept} header then we use that. If there * isn't then we use the {@code Content-Type} header which is required. 
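* If none of these resolves to a supported media type, {@code checkNonNullMediaType} rejects the request with an {@code IllegalArgumentException} instead of returning {@code null}.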
*/ - public MediaType getResponseMediaType(RestRequest request, SqlQueryRequest sqlRequest) { + public static MediaType getResponseMediaType(RestRequest request, SqlQueryRequest sqlRequest) { if (Mode.isDedicatedClient(sqlRequest.requestInfo().mode()) && (sqlRequest.binaryCommunication() == null || sqlRequest.binaryCommunication())) { // enforce CBOR response for drivers and CLI (unless instructed differently through the config param) return XContentType.CBOR; } else if (request.hasParam(URL_PARAM_FORMAT)) { - return validateColumnarRequest(sqlRequest.columnar(), - MEDIA_TYPE_REGISTRY.queryParamToMediaType(request.param(URL_PARAM_FORMAT))); + return validateColumnarRequest(sqlRequest.columnar(), mediaTypeFromParams(request), request); } - if (request.getParsedAccept() != null) { - return request.getParsedAccept().toMediaType(MEDIA_TYPE_REGISTRY); - } - return request.getXContentType(); + return mediaTypeFromHeaders(request); + } + + public static MediaType getResponseMediaType(RestRequest request) { + return request.hasParam(URL_PARAM_FORMAT) + ? checkNonNullMediaType(mediaTypeFromParams(request), request) + : mediaTypeFromHeaders(request); } - private static MediaType validateColumnarRequest(boolean requestIsColumnar, MediaType fromMediaType) { + private static MediaType mediaTypeFromHeaders(RestRequest request) { + ParsedMediaType acceptType = request.getParsedAccept(); + MediaType mediaType = acceptType != null ? acceptType.toMediaType(MEDIA_TYPE_REGISTRY) : request.getXContentType(); + return checkNonNullMediaType(mediaType, request); + } + + private static MediaType mediaTypeFromParams(RestRequest request) { + return MEDIA_TYPE_REGISTRY.queryParamToMediaType(request.param(URL_PARAM_FORMAT)); + } + + private static MediaType validateColumnarRequest(boolean requestIsColumnar, MediaType fromMediaType, RestRequest request) { if (requestIsColumnar && fromMediaType instanceof TextFormat) { throw new IllegalArgumentException("Invalid use of [columnar] argument: cannot be used in combination with " + "txt, csv or tsv formats"); } - return fromMediaType; + return checkNonNullMediaType(fromMediaType, request); + } + + private static MediaType checkNonNullMediaType(MediaType mediaType, RestRequest request) { + if (mediaType == null) { + String msg = String.format(Locale.ROOT, "Invalid request content type: Accept=[%s], Content-Type=[%s], format=[%s]", + request.header("Accept"), request.header("Content-Type"), request.param("format")); + throw new IllegalArgumentException(msg); + } + + return mediaType; } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlPlugin.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlPlugin.java index fa2674a926c2c..ab519eb241245 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlPlugin.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlPlugin.java @@ -106,7 +106,10 @@ public List getRestHandlers(Settings settings, RestController restC return Arrays.asList(new RestSqlQueryAction(), new RestSqlTranslateAction(), new RestSqlClearCursorAction(), - new RestSqlStatsAction()); + new RestSqlStatsAction(), + new RestSqlAsyncGetResultsAction(), + new RestSqlAsyncGetStatusAction(), + new RestSqlAsyncDeleteResultsAction()); } @Override @@ -118,6 +121,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(SqlTranslateAction.INSTANCE, TransportSqlTranslateAction.class), new ActionHandler<>(SqlClearCursorAction.INSTANCE, 
TransportSqlClearCursorAction.class), new ActionHandler<>(SqlStatsAction.INSTANCE, TransportSqlStatsAction.class), + new ActionHandler<>(SqlAsyncGetResultsAction.INSTANCE, TransportSqlAsyncGetResultsAction.class), + new ActionHandler<>(SqlAsyncGetStatusAction.INSTANCE, TransportSqlAsyncGetStatusAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlResponseListener.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlResponseListener.java new file mode 100644 index 0000000000000..200af40066d08 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/SqlResponseListener.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.common.xcontent.MediaType; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.rest.action.RestResponseListener; +import org.elasticsearch.xpack.sql.action.SqlQueryRequest; +import org.elasticsearch.xpack.sql.action.SqlQueryResponse; + +import java.nio.charset.StandardCharsets; +import java.util.Locale; + +import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_ASYNC_ID; +import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_ASYNC_PARTIAL; +import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_ASYNC_RUNNING; +import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_CURSOR; +import static org.elasticsearch.xpack.sql.proto.Protocol.HEADER_NAME_TOOK_NANOS; +import static org.elasticsearch.xpack.sql.proto.Protocol.URL_PARAM_DELIMITER; + +class SqlResponseListener extends RestResponseListener { + + private final long startNanos = System.nanoTime(); + private final MediaType mediaType; + private final RestRequest request; + + + SqlResponseListener(RestChannel channel, RestRequest request, SqlQueryRequest sqlRequest) { + super(channel); + this.request = request; + + this.mediaType = SqlMediaTypeParser.getResponseMediaType(request, sqlRequest); + + /* + * Special handling for the "delimiter" parameter which should only be + * checked for being present or not in the case of CSV format. We cannot + * override {@link BaseRestHandler#responseParams()} because this + * parameter should only be checked for CSV, not always. 
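+ * For every other media type an explicit "delimiter" is rejected below as an unrecognized parameter.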
+ */ + if (mediaType != TextFormat.CSV && request.hasParam(URL_PARAM_DELIMITER)) { + String message = String.format(Locale.ROOT, "request [%s] contains unrecognized parameter: [" + URL_PARAM_DELIMITER + "]", + request.path()); + throw new IllegalArgumentException(message); + } + } + + SqlResponseListener(RestChannel channel, RestRequest request) { + super(channel); + this.request = request; + this.mediaType = SqlMediaTypeParser.getResponseMediaType(request); + } + + @Override + public RestResponse buildResponse(SqlQueryResponse response) throws Exception { + RestResponse restResponse; + + // XContent branch + if (mediaType instanceof XContentType) { + XContentType type = (XContentType) mediaType; + XContentBuilder builder = channel.newBuilder(request.getXContentType(), type, true); + response.toXContent(builder, request); + restResponse = new BytesRestResponse(RestStatus.OK, builder); + } else { // TextFormat + TextFormat type = (TextFormat) mediaType; + final String data = type.format(request, response); + + restResponse = new BytesRestResponse(RestStatus.OK, type.contentType(request), + data.getBytes(StandardCharsets.UTF_8)); + + if (response.hasCursor()) { + restResponse.addHeader(HEADER_NAME_CURSOR, response.cursor()); + } + + if (response.hasId()) { + restResponse.addHeader(HEADER_NAME_ASYNC_ID, response.id()); + restResponse.addHeader(HEADER_NAME_ASYNC_PARTIAL, String.valueOf(response.isPartial())); + restResponse.addHeader(HEADER_NAME_ASYNC_RUNNING, String.valueOf(response.isRunning())); + } + } + + restResponse.addHeader(HEADER_NAME_TOOK_NANOS, Long.toString(System.nanoTime() - startNanos)); + return restResponse; + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TextFormat.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TextFormat.java index c26f6f9d25391..44206ad9e950c 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TextFormat.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TextFormat.java @@ -73,13 +73,12 @@ String format(RestRequest request, SqlQueryResponse response) { } // format with header return formatter.formatWithHeader(response.columns(), response.rows()); - } - else { - // should be initialized (wrapped by the cursor) - if (formatter != null) { - // format without header - return formatter.formatWithoutHeader(response.rows()); - } + } else if (formatter != null) { // should be initialized (wrapped by the cursor) + // format without header + return formatter.formatWithoutHeader(response.rows()); + } else if (response.hasId()) { + // an async request has no results yet + return StringUtils.EMPTY; } // if this code is reached, it means it's a next page without cursor wrapping throw new SqlIllegalArgumentException("Cannot find text formatter - this is likely a bug"); diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlAsyncGetResultsAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlAsyncGetResultsAction.java new file mode 100644 index 0000000000000..d647298a8f888 --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlAsyncGetResultsAction.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.ql.plugin.AbstractTransportQlAsyncGetResultsAction; +import org.elasticsearch.xpack.sql.action.SqlQueryResponse; +import org.elasticsearch.xpack.sql.action.SqlQueryTask; + +public class TransportSqlAsyncGetResultsAction extends AbstractTransportQlAsyncGetResultsAction { + + @Inject + public TransportSqlAsyncGetResultsAction(TransportService transportService, + ActionFilters actionFilters, + ClusterService clusterService, + NamedWriteableRegistry registry, + Client client, + ThreadPool threadPool, + BigArrays bigArrays) { + super(SqlAsyncGetResultsAction.NAME, transportService, actionFilters, clusterService, registry, client, threadPool, bigArrays, + SqlQueryTask.class); + } + + @Override + public Writeable.Reader responseReader() { + return SqlQueryResponse::new; + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlAsyncGetStatusAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlAsyncGetStatusAction.java new file mode 100644 index 0000000000000..fa9c8fbc5cd6b --- /dev/null +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlAsyncGetStatusAction.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
 */ +package org.elasticsearch.xpack.sql.plugin; + +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.ql.plugin.AbstractTransportQlAsyncGetStatusAction; +import org.elasticsearch.xpack.sql.action.SqlQueryResponse; +import org.elasticsearch.xpack.sql.action.SqlQueryTask; + + +public class TransportSqlAsyncGetStatusAction extends AbstractTransportQlAsyncGetStatusAction<SqlQueryResponse, SqlQueryTask> { + @Inject + public TransportSqlAsyncGetStatusAction(TransportService transportService, + ActionFilters actionFilters, + ClusterService clusterService, + NamedWriteableRegistry registry, + Client client, + ThreadPool threadPool, + BigArrays bigArrays) { + super(SqlAsyncGetStatusAction.NAME, transportService, actionFilters, clusterService, registry, client, threadPool, bigArrays, + SqlQueryTask.class); + } + + @Override + protected Writeable.Reader<SqlQueryResponse> responseReader() { + return SqlQueryResponse::new; + } +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlQueryAction.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlQueryAction.java index 2e07e170a2d56..19c3f0faf2760 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlQueryAction.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/plugin/TransportSqlQueryAction.java @@ -14,19 +14,26 @@ import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.core.Tuple; import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.async.AsyncExecutionId; import org.elasticsearch.xpack.core.security.SecurityContext; +import org.elasticsearch.xpack.ql.async.AsyncTaskManagementService; import org.elasticsearch.xpack.ql.type.Schema; import org.elasticsearch.xpack.sql.SqlIllegalArgumentException; import org.elasticsearch.xpack.sql.action.SqlQueryAction; import org.elasticsearch.xpack.sql.action.SqlQueryRequest; import org.elasticsearch.xpack.sql.action.SqlQueryResponse; +import org.elasticsearch.xpack.sql.action.SqlQueryTask; import org.elasticsearch.xpack.sql.execution.PlanExecutor; import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape; import org.elasticsearch.xpack.sql.expression.literal.interval.Interval; @@ -40,29 +47,35 @@ import org.elasticsearch.xpack.sql.session.SqlConfiguration; import org.elasticsearch.xpack.sql.type.SqlDataTypes; +import java.io.IOException; import java.time.ZoneId; import java.util.ArrayList; import java.util.List; +import java.util.Map; import static java.util.Collections.unmodifiableList; import
static org.elasticsearch.action.ActionListener.wrap; +import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN; import static org.elasticsearch.xpack.ql.plugin.TransportActionUtils.executeRequestWithRetryAttempt; import static org.elasticsearch.xpack.sql.plugin.Transports.clusterName; import static org.elasticsearch.xpack.sql.plugin.Transports.username; import static org.elasticsearch.xpack.sql.proto.Mode.CLI; -public class TransportSqlQueryAction extends HandledTransportAction<SqlQueryRequest, SqlQueryResponse> { +public class TransportSqlQueryAction extends HandledTransportAction<SqlQueryRequest, SqlQueryResponse> + implements AsyncTaskManagementService.AsyncOperation<SqlQueryRequest, SqlQueryResponse, SqlQueryTask> { + private static final Logger log = LogManager.getLogger(TransportSqlQueryAction.class); private final SecurityContext securityContext; private final ClusterService clusterService; private final PlanExecutor planExecutor; private final SqlLicenseChecker sqlLicenseChecker; private final TransportService transportService; + private final AsyncTaskManagementService<SqlQueryRequest, SqlQueryResponse, SqlQueryTask> asyncTaskManagementService; @Inject public TransportSqlQueryAction(Settings settings, ClusterService clusterService, TransportService transportService, ThreadPool threadPool, ActionFilters actionFilters, PlanExecutor planExecutor, - SqlLicenseChecker sqlLicenseChecker) { + SqlLicenseChecker sqlLicenseChecker, BigArrays bigArrays) { super(SqlQueryAction.NAME, transportService, actionFilters, SqlQueryRequest::new); this.securityContext = XPackSettings.SECURITY_ENABLED.get(settings) ? @@ -71,42 +84,53 @@ public TransportSqlQueryAction(Settings settings, ClusterService clusterService, this.planExecutor = planExecutor; this.sqlLicenseChecker = sqlLicenseChecker; this.transportService = transportService; + + asyncTaskManagementService = new AsyncTaskManagementService<>(XPackPlugin.ASYNC_RESULTS_INDEX, planExecutor.client(), + ASYNC_SEARCH_ORIGIN, planExecutor.writeableRegistry(), taskManager, SqlQueryAction.INSTANCE.name(), this, SqlQueryTask.class, + clusterService, threadPool, bigArrays); } @Override protected void doExecute(Task task, SqlQueryRequest request, ActionListener<SqlQueryResponse> listener) { sqlLicenseChecker.checkIfSqlAllowed(request.mode()); - operation(planExecutor, request, listener, username(securityContext), clusterName(clusterService), transportService, - clusterService); + if (request.waitForCompletionTimeout() != null && request.waitForCompletionTimeout().getMillis() >= 0) { + asyncTaskManagementService.asyncExecute(request, request.waitForCompletionTimeout(), request.keepAlive(), + request.keepOnCompletion(), listener); + } else { + operation(planExecutor, (SqlQueryTask) task, request, listener, username(securityContext), transportService, clusterService); + } } /** * Actual implementation of the action. Statically available to support embedded mode.
 */ - static void operation(PlanExecutor planExecutor, SqlQueryRequest request, ActionListener<SqlQueryResponse> listener, - String username, String clusterName, TransportService transportService, ClusterService clusterService) { + public static void operation(PlanExecutor planExecutor, SqlQueryTask task, SqlQueryRequest request, + ActionListener<SqlQueryResponse> listener, String username, TransportService transportService, + ClusterService clusterService) { // The configuration is always created however when dealing with the next page, only the timeouts are relevant // the rest having default values (since the query is already created) SqlConfiguration cfg = new SqlConfiguration(request.zoneId(), request.fetchSize(), request.requestTimeout(), request.pageTimeout(), - request.filter(), request.runtimeMappings(), request.mode(), request.clientId(), request.version(), username, clusterName, - request.fieldMultiValueLeniency(), request.indexIncludeFrozen()); + request.filter(), request.runtimeMappings(), request.mode(), request.clientId(), request.version(), username, + clusterName(clusterService), request.fieldMultiValueLeniency(), request.indexIncludeFrozen(), + new TaskId(clusterService.localNode().getId(), task.getId()), task, + request.waitForCompletionTimeout(), request.keepOnCompletion(), request.keepAlive()); if (Strings.hasText(request.cursor()) == false) { executeRequestWithRetryAttempt(clusterService, listener::onFailure, onFailure -> planExecutor.sql(cfg, request.query(), request.params(), - wrap(p -> listener.onResponse(createResponseWithSchema(request, p)), onFailure)), + wrap(p -> listener.onResponse(createResponseWithSchema(request, p, task)), onFailure)), node -> transportService.sendRequest(node, SqlQueryAction.NAME, request, new ActionListenerResponseHandler<>(listener, SqlQueryResponse::new, ThreadPool.Names.SAME)), log); } else { Tuple<Cursor, ZoneId> decoded = Cursors.decodeFromStringWithZone(request.cursor()); planExecutor.nextPage(cfg, decoded.v1(), - wrap(p -> listener.onResponse(createResponse(request, decoded.v2(), null, p)), + wrap(p -> listener.onResponse(createResponse(request, decoded.v2(), null, p, task)), listener::onFailure)); } } - private static SqlQueryResponse createResponseWithSchema(SqlQueryRequest request, Page page) { + private static SqlQueryResponse createResponseWithSchema(SqlQueryRequest request, Page page, SqlQueryTask task) { RowSet rset = page.rowSet(); if ((rset instanceof SchemaRowSet) == false) { throw new SqlIllegalArgumentException("No schema found inside {}", rset.getClass()); @@ -122,10 +146,11 @@ private static SqlQueryResponse createResponseWithSchema(SqlQueryRequest request } } columns = unmodifiableList(columns); - return createResponse(request, request.zoneId(), columns, page); + return createResponse(request, request.zoneId(), columns, page, task); } - private static SqlQueryResponse createResponse(SqlQueryRequest request, ZoneId zoneId, List<ColumnInfo> header, Page page) { + private static SqlQueryResponse createResponse(SqlQueryRequest request, ZoneId zoneId, List<ColumnInfo> header, Page page, + SqlQueryTask task) { List<List<Object>> rows = new ArrayList<>(); page.rowSet().forEachRow(rowView -> { List<Object> row = new ArrayList<>(rowView.columnCount()); @@ -133,13 +158,17 @@ private static SqlQueryResponse createResponse(SqlQueryRequest request, ZoneId z rows.add(unmodifiableList(row)); }); + AsyncExecutionId executionId = task.getExecutionId(); return new SqlQueryResponse( Cursors.encodeToString(page.next(), zoneId), request.mode(), request.version(), request.columnar(), header, - rows); + rows, + executionId == null ?
null : executionId.getEncoded(), + false, false + ); } @SuppressWarnings("rawtypes") @@ -162,4 +191,26 @@ private static Object value(Object r, Mode mode) { return r; } + + @Override + public SqlQueryTask createTask(SqlQueryRequest request, long id, String type, String action, TaskId parentTaskId, + Map<String, String> headers, Map<String, String> originHeaders, AsyncExecutionId asyncExecutionId) { + return new SqlQueryTask(id, type, action, request.getDescription(), parentTaskId, headers, originHeaders, asyncExecutionId, + request.keepAlive(), request.mode(), request.version(), request.columnar()); + } + + @Override + public void execute(SqlQueryRequest request, SqlQueryTask task, ActionListener<SqlQueryResponse> listener) { + operation(planExecutor, task, request, listener, username(securityContext), transportService, clusterService); + } + + @Override + public SqlQueryResponse initialResponse(SqlQueryTask task) { + return task.getCurrentResult(); + } + + @Override + public SqlQueryResponse readResponse(StreamInput inputStream) throws IOException { + return new SqlQueryResponse(inputStream); + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlConfiguration.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlConfiguration.java index 4e904cc2a4867..1ca380832aa43 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlConfiguration.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlConfiguration.java @@ -10,7 +10,10 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.sql.action.SqlQueryTask; import org.elasticsearch.xpack.sql.proto.Mode; +import org.elasticsearch.xpack.sql.proto.Protocol; import org.elasticsearch.xpack.sql.proto.SqlVersion; import java.time.ZoneId; @@ -27,6 +30,14 @@ public class SqlConfiguration extends org.elasticsearch.xpack.ql.session.Configu private final SqlVersion version; private final boolean multiValueFieldLeniency; private final boolean includeFrozenIndices; + private final TimeValue waitForCompletionTimeout; + private final boolean keepOnCompletion; + private final TimeValue keepAlive; + + @Nullable + private final TaskId taskId; + @Nullable + private final SqlQueryTask task; @Nullable private QueryBuilder filter; @@ -39,7 +50,10 @@ public SqlConfiguration(ZoneId zi, int pageSize, TimeValue requestTimeout, TimeV Mode mode, String clientId, SqlVersion version, String username, String clusterName, boolean multiValueFieldLeniency, - boolean includeFrozen) { + boolean includeFrozen, + @Nullable TaskId taskId, + @Nullable SqlQueryTask task, + TimeValue waitForCompletionTimeout, boolean keepOnCompletion, TimeValue keepAlive) { super(zi, username, clusterName); @@ -53,6 +67,22 @@ public SqlConfiguration(ZoneId zi, int pageSize, TimeValue requestTimeout, TimeV this.version = version != null ?
version : SqlVersion.fromId(Version.CURRENT.id); this.multiValueFieldLeniency = multiValueFieldLeniency; this.includeFrozenIndices = includeFrozen; + this.taskId = taskId; + this.task = task; + this.waitForCompletionTimeout = waitForCompletionTimeout; + this.keepOnCompletion = keepOnCompletion; + this.keepAlive = keepAlive; + } + + public SqlConfiguration(ZoneId zi, int pageSize, TimeValue requestTimeout, TimeValue pageTimeout, QueryBuilder filter, + Map runtimeMappings, + Mode mode, String clientId, SqlVersion version, + String username, String clusterName, + boolean multiValueFieldLeniency, + boolean includeFrozen) { + this(zi, pageSize, requestTimeout, pageTimeout, filter, runtimeMappings, mode, clientId, version, username, clusterName, + multiValueFieldLeniency, includeFrozen, null, null, Protocol.DEFAULT_WAIT_FOR_COMPLETION_TIMEOUT, + Protocol.DEFAULT_KEEP_ON_COMPLETION, Protocol.DEFAULT_KEEP_ALIVE); } public int pageSize() { @@ -94,4 +124,24 @@ public boolean includeFrozen() { public SqlVersion version() { return version; } + + public TaskId taskId() { + return taskId; + } + + public SqlQueryTask task() { + return task; + } + + public TimeValue waitForCompletionTimeout() { + return waitForCompletionTimeout; + } + + public boolean keepOnCompletion() { + return keepOnCompletion; + } + + public TimeValue keepAlive() { + return keepAlive; + } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java index 7fa9069fa244c..accd85d60541d 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/session/SqlSession.java @@ -8,7 +8,9 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.Client; +import org.elasticsearch.client.ParentTaskAssigningClient; import org.elasticsearch.common.Strings; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.ql.index.IndexResolution; import org.elasticsearch.xpack.ql.index.IndexResolver; @@ -55,7 +57,7 @@ public SqlSession(SqlConfiguration configuration, Client client, FunctionRegistr Optimizer optimizer, Planner planner, PlanExecutor planExecutor) { - this.client = client; + this.client = configuration.taskId() != null ? 
new ParentTaskAssigningClient(client, configuration.taskId()) : client; this.functionRegistry = functionRegistry; this.indexResolver = indexResolver; @@ -125,6 +127,11 @@ public void debugAnalyzedPlan(LogicalPlan parsed, ActionListener void preAnalyze(LogicalPlan parsed, Function action, ActionListener listener) { + if (configuration.task() != null && configuration.task().isCancelled()) { + listener.onFailure(new TaskCancelledException("cancelled")); + return; + } + PreAnalysis preAnalysis = preAnalyzer.preAnalyze(parsed); // TODO we plan to support joins in the future when possible, but for now we'll just fail early if we see one if (preAnalysis.indices.size() > 1) { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/SqlTestUtils.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/SqlTestUtils.java index c76a7e596e1f5..5b4f0266e6994 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/SqlTestUtils.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/SqlTestUtils.java @@ -8,8 +8,12 @@ package org.elasticsearch.xpack.sql; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.xpack.core.async.AsyncExecutionId; import org.elasticsearch.xpack.ql.expression.Literal; import org.elasticsearch.xpack.ql.tree.Source; +import org.elasticsearch.xpack.sql.action.SqlQueryAction; +import org.elasticsearch.xpack.sql.action.SqlQueryTask; import org.elasticsearch.xpack.sql.proto.Mode; import org.elasticsearch.xpack.sql.proto.Protocol; import org.elasticsearch.xpack.sql.proto.SqlVersion; @@ -25,6 +29,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static java.util.Collections.emptyMap; import static org.elasticsearch.test.ESTestCase.randomAlphaOfLength; import static org.elasticsearch.test.ESTestCase.randomBoolean; import static org.elasticsearch.test.ESTestCase.randomFrom; @@ -42,52 +47,45 @@ private SqlTestUtils() {} Protocol.REQUEST_TIMEOUT, Protocol.PAGE_TIMEOUT, null, null, Mode.PLAIN, null, null, null, null, false, false); - public static SqlConfiguration randomConfiguration() { - return new SqlConfiguration(randomZone(), - randomIntBetween(0, 1000), - new TimeValue(randomNonNegativeLong()), - new TimeValue(randomNonNegativeLong()), - null, - null, - randomFrom(Mode.values()), - randomAlphaOfLength(10), - null, - randomAlphaOfLength(10), - randomAlphaOfLength(10), - false, - randomBoolean()); - } - - public static SqlConfiguration randomConfiguration(ZoneId providedZoneId) { - return new SqlConfiguration(providedZoneId, + public static SqlConfiguration randomConfiguration(ZoneId providedZoneId, SqlVersion sqlVersion) { + Mode mode = randomFrom(Mode.values()); + long taskId = randomNonNegativeLong(); + return new SqlConfiguration(providedZoneId != null ? 
providedZoneId : randomZone(), randomIntBetween(0, 1000), new TimeValue(randomNonNegativeLong()), new TimeValue(randomNonNegativeLong()), null, null, - randomFrom(Mode.values()), + mode, randomAlphaOfLength(10), - null, + sqlVersion, randomAlphaOfLength(10), randomAlphaOfLength(10), false, - randomBoolean()); + randomBoolean(), + new TaskId(randomAlphaOfLength(10), taskId), + randomTask(taskId, mode, sqlVersion), + new TimeValue(randomNonNegativeLong()), + randomBoolean(), + new TimeValue(randomNonNegativeLong())); + } + + public static SqlConfiguration randomConfiguration() { + return randomConfiguration(null, null); + } + + public static SqlConfiguration randomConfiguration(ZoneId providedZoneId) { + return randomConfiguration(providedZoneId, null); } public static SqlConfiguration randomConfiguration(SqlVersion version) { - return new SqlConfiguration(randomZone(), - randomIntBetween(0, 1000), - new TimeValue(randomNonNegativeLong()), - new TimeValue(randomNonNegativeLong()), - null, - null, - randomFrom(Mode.values()), - randomAlphaOfLength(10), - version, - randomAlphaOfLength(10), - randomAlphaOfLength(10), - false, - randomBoolean()); + return randomConfiguration(null, version); + } + + public static SqlQueryTask randomTask(long taskId, Mode mode, SqlVersion sqlVersion) { + return new SqlQueryTask(taskId, "transport", SqlQueryAction.NAME, "", null, emptyMap(), emptyMap(), + new AsyncExecutionId("", new TaskId(randomAlphaOfLength(10), 1)), TimeValue.timeValueDays(5), mode, sqlVersion, + randomBoolean()); } public static String randomWhitespaces() { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/CancellationTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/CancellationTests.java new file mode 100644 index 0000000000000..691fb189a5b46 --- /dev/null +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/CancellationTests.java @@ -0,0 +1,234 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.sql.analysis; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.fieldcaps.FieldCapabilities; +import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchRequestBuilder; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.sql.action.SqlQueryAction; +import org.elasticsearch.xpack.sql.action.SqlQueryRequest; +import org.elasticsearch.xpack.sql.action.SqlQueryRequestBuilder; +import org.elasticsearch.xpack.sql.action.SqlQueryResponse; +import org.elasticsearch.xpack.sql.action.SqlQueryTask; +import org.elasticsearch.xpack.sql.execution.PlanExecutor; +import org.elasticsearch.xpack.sql.plugin.TransportSqlQueryAction; +import org.elasticsearch.xpack.ql.index.IndexResolver; +import org.elasticsearch.xpack.ql.type.DefaultDataTypeRegistry; +import org.mockito.ArgumentCaptor; +import org.mockito.stubbing.Answer; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class CancellationTests extends ESTestCase { + + public void testCancellationBeforeFieldCaps() throws InterruptedException { + Client client = mock(Client.class); + SqlQueryTask task = mock(SqlQueryTask.class); + when(task.isCancelled()).thenReturn(true); + ClusterService mockClusterService = mockClusterService(); + + IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE); + PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList())); + CountDownLatch countDownLatch = new CountDownLatch(1); + SqlQueryRequest request = new SqlQueryRequestBuilder(client, SqlQueryAction.INSTANCE).query("SELECT foo FROM bar").request(); + TransportSqlQueryAction.operation(planExecutor, task, request, new ActionListener<>() { + @Override + public void onResponse(SqlQueryResponse sqlSearchResponse) { + fail("Shouldn't be here"); + countDownLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + assertThat(e, instanceOf(TaskCancelledException.class)); + countDownLatch.countDown(); + } + }, "", mock(TransportService.class), mockClusterService); + countDownLatch.await(); + verify(task, times(1)).isCancelled(); + verify(task, times(1)).getId(); + verify(client, times(1)).settings(); + verify(client, 
times(1)).threadPool(); + verifyNoMoreInteractions(client, task); + } + + private Map> fields(String[] indices) { + FieldCapabilities fooField = + new FieldCapabilities("foo", "integer", false, true, true, indices, null, null, emptyMap()); + FieldCapabilities categoryField = + new FieldCapabilities("event.category", "keyword", false, true, true, indices, null, null, emptyMap()); + FieldCapabilities timestampField = + new FieldCapabilities("@timestamp", "date", false, true, true, indices, null, null, emptyMap()); + Map> fields = new HashMap<>(); + fields.put(fooField.getName(), singletonMap(fooField.getName(), fooField)); + fields.put(categoryField.getName(), singletonMap(categoryField.getName(), categoryField)); + fields.put(timestampField.getName(), singletonMap(timestampField.getName(), timestampField)); + return fields; + } + + public void testCancellationBeforeSearch() throws InterruptedException { + Client client = mock(Client.class); + + AtomicBoolean cancelled = new AtomicBoolean(false); + SqlQueryTask task = mock(SqlQueryTask.class); + long taskId = randomNonNegativeLong(); + when(task.isCancelled()).then(invocationOnMock -> cancelled.get()); + when(task.getId()).thenReturn(taskId); + ClusterService mockClusterService = mockClusterService(); + + String[] indices = new String[]{"endgame"}; + + FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class); + when(fieldCapabilitiesResponse.getIndices()).thenReturn(indices); + when(fieldCapabilitiesResponse.get()).thenReturn(fields(indices)); + doAnswer((Answer) invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + assertFalse(cancelled.getAndSet(true)); + listener.onResponse(fieldCapabilitiesResponse); + return null; + }).when(client).fieldCaps(any(), any()); + + + IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE); + PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList())); + CountDownLatch countDownLatch = new CountDownLatch(1); + SqlQueryRequest request = new SqlQueryRequestBuilder(client, SqlQueryAction.INSTANCE) + .query("SELECT foo FROM " + indices[0]).request(); + TransportSqlQueryAction.operation(planExecutor, task, request, new ActionListener<>() { + @Override + public void onResponse(SqlQueryResponse sqlSearchResponse) { + fail("Shouldn't be here"); + countDownLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + assertThat(e, instanceOf(TaskCancelledException.class)); + countDownLatch.countDown(); + } + }, "", mock(TransportService.class), mockClusterService); + countDownLatch.await(); + verify(client, times(1)).fieldCaps(any(), any()); + verify(task, times(2)).isCancelled(); + verify(task, times(1)).getId(); + verify(client, times(1)).settings(); + verify(client, times(1)).threadPool(); + verifyNoMoreInteractions(client, task); + } + + public void testCancellationDuringSearch() throws InterruptedException { + Client client = mock(Client.class); + + SqlQueryTask task = mock(SqlQueryTask.class); + String nodeId = randomAlphaOfLength(10); + long taskId = randomNonNegativeLong(); + when(task.isCancelled()).thenReturn(false); + when(task.getId()).thenReturn(taskId); + ClusterService mockClusterService = mockClusterService(nodeId); + + String[] indices = new String[]{"endgame"}; + + // Emulation of field capabilities + FieldCapabilitiesResponse fieldCapabilitiesResponse = 
mock(FieldCapabilitiesResponse.class); + when(fieldCapabilitiesResponse.getIndices()).thenReturn(indices); + when(fieldCapabilitiesResponse.get()).thenReturn(fields(indices)); + doAnswer((Answer) invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(fieldCapabilitiesResponse); + return null; + }).when(client).fieldCaps(any(), any()); + + // Emulation of search cancellation + ArgumentCaptor searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class); + when(client.prepareSearch(any())).thenReturn(new SearchRequestBuilder(client, SearchAction.INSTANCE).setIndices(indices)); + doAnswer((Answer) invocation -> { + @SuppressWarnings("unchecked") + SearchRequest request = (SearchRequest) invocation.getArguments()[1]; + TaskId parentTask = request.getParentTask(); + assertNotNull(parentTask); + assertEquals(taskId, parentTask.getId()); + assertEquals(nodeId, parentTask.getNodeId()); + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new TaskCancelledException("cancelled")); + return null; + }).when(client).execute(any(), searchRequestCaptor.capture(), any()); + + IndexResolver indexResolver = new IndexResolver(client, randomAlphaOfLength(10), DefaultDataTypeRegistry.INSTANCE); + PlanExecutor planExecutor = new PlanExecutor(client, indexResolver, new NamedWriteableRegistry(Collections.emptyList())); + SqlQueryRequest request = new SqlQueryRequestBuilder(client, SqlQueryAction.INSTANCE) + .query("SELECT foo FROM " + indices[0]).request(); + CountDownLatch countDownLatch = new CountDownLatch(1); + TransportSqlQueryAction.operation(planExecutor, task, request, new ActionListener<>() { + @Override + public void onResponse(SqlQueryResponse sqlSearchResponse) { + fail("Shouldn't be here"); + countDownLatch.countDown(); + } + + @Override + public void onFailure(Exception e) { + assertThat(e, instanceOf(TaskCancelledException.class)); + countDownLatch.countDown(); + } + }, "", mock(TransportService.class), mockClusterService); + countDownLatch.await(); + // Final verification to ensure no more interaction + verify(client).fieldCaps(any(), any()); + verify(client).execute(any(), any(), any()); + verify(task, times(2)).isCancelled(); + verify(task, times(1)).getId(); + verify(client, times(1)).settings(); + verify(client, times(1)).threadPool(); + verifyNoMoreInteractions(client, task); + } + + private ClusterService mockClusterService() { + return mockClusterService(null); + } + + private ClusterService mockClusterService(String nodeId) { + final ClusterService mockClusterService = mock(ClusterService.class); + final DiscoveryNode mockNode = mock(DiscoveryNode.class); + final ClusterName mockClusterName = mock(ClusterName.class); + when(mockNode.getId()).thenReturn(nodeId == null ? 
randomAlphaOfLength(10) : nodeId); + when(mockClusterService.localNode()).thenReturn(mockNode); + when(mockClusterName.value()).thenReturn(randomAlphaOfLength(10)); + when(mockClusterService.getClusterName()).thenReturn(mockClusterName); + return mockClusterService; + } +} diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParserTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParserTests.java index 1b33712d1f157..3da565a0b7fec 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParserTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlMediaTypeParserTests.java @@ -20,61 +20,69 @@ import java.util.Collections; import java.util.Map; +import static org.elasticsearch.xpack.sql.plugin.SqlMediaTypeParser.getResponseMediaType; import static org.elasticsearch.xpack.sql.plugin.TextFormat.CSV; import static org.elasticsearch.xpack.sql.plugin.TextFormat.PLAIN_TEXT; import static org.elasticsearch.xpack.sql.plugin.TextFormat.TSV; import static org.elasticsearch.xpack.sql.proto.RequestInfo.CLIENT_IDS; import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.nullValue; public class SqlMediaTypeParserTests extends ESTestCase { - SqlMediaTypeParser parser = new SqlMediaTypeParser(); public void testPlainTextDetection() { - MediaType text = parser.getResponseMediaType(reqWithAccept("text/plain"), createTestInstance(false, Mode.PLAIN, false)); + MediaType text = getResponseMediaType(reqWithAccept("text/plain"), createTestInstance(false, Mode.PLAIN, false)); assertThat(text, is(PLAIN_TEXT)); } public void testCsvDetection() { - MediaType text = parser.getResponseMediaType(reqWithAccept("text/csv"), createTestInstance(false, Mode.PLAIN, false)); + MediaType text = getResponseMediaType(reqWithAccept("text/csv"), createTestInstance(false, Mode.PLAIN, false)); assertThat(text, is(CSV)); - text = parser.getResponseMediaType(reqWithAccept("text/csv; delimiter=x"), createTestInstance(false, Mode.PLAIN, false)); + text = getResponseMediaType(reqWithAccept("text/csv; delimiter=x"), createTestInstance(false, Mode.PLAIN, false)); assertThat(text, is(CSV)); } public void testTsvDetection() { - MediaType text = parser.getResponseMediaType(reqWithAccept("text/tab-separated-values"), + MediaType text = getResponseMediaType(reqWithAccept("text/tab-separated-values"), createTestInstance(false, Mode.PLAIN, false)); assertThat(text, is(TSV)); } public void testMediaTypeDetectionWithParameters() { - assertThat(parser.getResponseMediaType(reqWithAccept("text/plain; charset=utf-8"), + assertThat(getResponseMediaType(reqWithAccept("text/plain; charset=utf-8"), createTestInstance(false, Mode.PLAIN, false)), is(PLAIN_TEXT)); - assertThat(parser.getResponseMediaType(reqWithAccept("text/plain; header=present"), + assertThat(getResponseMediaType(reqWithAccept("text/plain; header=present"), createTestInstance(false, Mode.PLAIN, false)), is(PLAIN_TEXT)); - assertThat(parser.getResponseMediaType(reqWithAccept("text/plain; charset=utf-8; header=present"), + assertThat(getResponseMediaType(reqWithAccept("text/plain; charset=utf-8; header=present"), createTestInstance(false, Mode.PLAIN, false)), is(PLAIN_TEXT)); - assertThat(parser.getResponseMediaType(reqWithAccept("text/csv; charset=utf-8"), + assertThat(getResponseMediaType(reqWithAccept("text/csv; charset=utf-8"), createTestInstance(false, Mode.PLAIN, false)), is(CSV)); - 
assertThat(parser.getResponseMediaType(reqWithAccept("text/csv; header=present"), + assertThat(getResponseMediaType(reqWithAccept("text/csv; header=present"), createTestInstance(false, Mode.PLAIN, false)), is(CSV)); - assertThat(parser.getResponseMediaType(reqWithAccept("text/csv; charset=utf-8; header=present"), + assertThat(getResponseMediaType(reqWithAccept("text/csv; charset=utf-8; header=present"), createTestInstance(false, Mode.PLAIN, false)), is(CSV)); - assertThat(parser.getResponseMediaType(reqWithAccept("text/tab-separated-values; charset=utf-8"), + assertThat(getResponseMediaType(reqWithAccept("text/tab-separated-values; charset=utf-8"), createTestInstance(false, Mode.PLAIN, false)), is(TSV)); - assertThat(parser.getResponseMediaType(reqWithAccept("text/tab-separated-values; header=present"), + assertThat(getResponseMediaType(reqWithAccept("text/tab-separated-values; header=present"), createTestInstance(false, Mode.PLAIN, false)), is(TSV)); - assertThat(parser.getResponseMediaType(reqWithAccept("text/tab-separated-values; charset=utf-8; header=present"), + assertThat(getResponseMediaType(reqWithAccept("text/tab-separated-values; charset=utf-8; header=present"), createTestInstance(false, Mode.PLAIN, false)), is(TSV)); } public void testInvalidFormat() { - MediaType mediaType = parser.getResponseMediaType(reqWithAccept("text/garbage"), createTestInstance(false, Mode.PLAIN, false)); - assertThat(mediaType, is(nullValue())); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> getResponseMediaType(reqWithAccept("text/garbage"), createTestInstance(false, Mode.PLAIN, false))); + assertEquals(e.getMessage(), + "Invalid request content type: Accept=[text/garbage], Content-Type=[application/json], format=[null]"); + } + + public void testNoFormat() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> getResponseMediaType(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(), + createTestInstance(false, Mode.PLAIN, false))); + assertEquals(e.getMessage(), "Invalid request content type: Accept=[null], Content-Type=[null], format=[null]"); } private static RestRequest reqWithAccept(String acceptHeader) { @@ -90,6 +98,7 @@ protected SqlQueryRequest createTestInstance(boolean binaryCommunication, Mode m randomZone(), between(1, Integer.MAX_VALUE), TimeValue.parseTimeValue(randomTimeValue(), null, "test"), TimeValue.parseTimeValue(randomTimeValue(), null, "test"), columnar, randomAlphaOfLength(10), new RequestInfo(mode, randomFrom(randomFrom(CLIENT_IDS), randomAlphaOfLengthBetween(10, 20))), - randomBoolean(), randomBoolean()).binaryCommunication(binaryCommunication); + randomBoolean(), randomBoolean(), TimeValue.parseTimeValue(randomTimeValue(), null, "test"), + randomBoolean(), TimeValue.parseTimeValue(randomTimeValue(), null, "test")).binaryCommunication(binaryCommunication); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlPluginTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlPluginTests.java index d559d69607b4d..15d345c5acada 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlPluginTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/plugin/SqlPluginTests.java @@ -30,12 +30,12 @@ public void testSqlDisabledIsNoOp() { SqlPlugin plugin = new SqlPlugin(settings); assertThat(plugin.createComponents(mock(Client.class), "cluster", new NamedWriteableRegistry(Cursors.getNamedWriteables())), 
hasSize(3)); - assertThat(plugin.getActions(), hasSize(6)); + assertThat(plugin.getActions(), hasSize(8)); assertThat( plugin.getRestHandlers(Settings.EMPTY, mock(RestController.class), new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS), IndexScopedSettings.DEFAULT_SCOPED_SETTINGS, new SettingsFilter(Collections.emptyList()), mock(IndexNameExpressionResolver.class), () -> mock(DiscoveryNodes.class)), - hasSize(4)); + hasSize(7)); } }
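
Reader's note (illustration only, not part of the patch): taken together, the changes above let a SQL query continue as a stored async task whenever the request carries a non-negative wait-for-completion timeout, mirroring the async EQL plumbing. The sketch below shows how a transport client might opt into that path; the three builder setters (waitForCompletionTimeout, keepOnCompletion, keepAlive) and the SqlQueryResponse id() accessor are assumptions for illustration and may not match the final API surface.

    import org.elasticsearch.action.ActionListener;
    import org.elasticsearch.client.Client;
    import org.elasticsearch.core.TimeValue;
    import org.elasticsearch.xpack.sql.action.SqlQueryAction;
    import org.elasticsearch.xpack.sql.action.SqlQueryRequest;
    import org.elasticsearch.xpack.sql.action.SqlQueryRequestBuilder;

    public final class AsyncSqlSketch {
        private AsyncSqlSketch() {}

        // Submits a query that is allowed to outlive the request and be stored as an async result.
        public static void submit(Client client) {
            SqlQueryRequest request = new SqlQueryRequestBuilder(client, SqlQueryAction.INSTANCE)
                .query("SELECT foo FROM bar")
                .waitForCompletionTimeout(TimeValue.timeValueSeconds(2)) // assumed setter; a non-negative value selects the async branch in doExecute()
                .keepOnCompletion(true)                                  // assumed setter; keep the stored result even if the query finishes in time
                .keepAlive(TimeValue.timeValueDays(5))                   // assumed setter; retention period for the stored result
                .request();

            client.execute(SqlQueryAction.INSTANCE, request, ActionListener.wrap(
                // id() accessor assumed; it would carry the encoded AsyncExecutionId when one was assigned.
                response -> System.out.println("async execution id: " + response.id()),
                e -> { throw new RuntimeException(e); }));
        }
    }

In the synchronous branch (no timeout, or a negative one) behaviour is unchanged: operation() runs directly against the request's own task and the response carries no async execution id.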