Skip to content

Commit 273ac15

Browse files
authored
Support canceling cross-clusters search requests (#66206)
This commit adds support for canceling cross-cluster search requests. It makes several important changes:
- Set the parent task for CCS search requests
- Keep track of underlying connections instead of proxy connections
- Assign the parent task for proxy requests
1 parent eff2c26 commit 273ac15

File tree

13 files changed

+163
-48
lines changed

13 files changed

+163
-48
lines changed

server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import org.elasticsearch.search.builder.SearchSourceBuilder;
6868
import org.elasticsearch.search.fetch.FetchSubPhase;
6969
import org.elasticsearch.search.fetch.FetchSubPhaseProcessor;
70+
import org.elasticsearch.tasks.TaskId;
7071
import org.elasticsearch.test.ESIntegTestCase;
7172

7273
import java.io.IOException;
@@ -131,9 +132,10 @@ public void testLocalClusterAlias() {
131132
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
132133
IndexResponse indexResponse = client().index(indexRequest).actionGet();
133134
assertEquals(RestStatus.CREATED, indexResponse.status());
135+
TaskId parentTaskId = new TaskId("node", randomNonNegativeLong());
134136

135137
{
136-
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
138+
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY,
137139
"local", nowInMillis, randomBoolean());
138140
SearchResponse searchResponse = client().search(searchRequest).actionGet();
139141
assertEquals(1, searchResponse.getHits().getTotalHits().value);
@@ -145,7 +147,7 @@ public void testLocalClusterAlias() {
145147
assertEquals("1", hit.getId());
146148
}
147149
{
148-
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY,
150+
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(), Strings.EMPTY_ARRAY,
149151
"", nowInMillis, randomBoolean());
150152
SearchResponse searchResponse = client().search(searchRequest).actionGet();
151153
assertEquals(1, searchResponse.getHits().getTotalHits().value);
@@ -159,6 +161,7 @@ public void testLocalClusterAlias() {
159161
}
160162

161163
public void testAbsoluteStartMillis() {
164+
TaskId parentTaskId = new TaskId("node", randomNonNegativeLong());
162165
{
163166
IndexRequest indexRequest = new IndexRequest("test-1970.01.01");
164167
indexRequest.id("1");
@@ -187,21 +190,21 @@ public void testAbsoluteStartMillis() {
187190
assertEquals(0, searchResponse.getTotalShards());
188191
}
189192
{
190-
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
193+
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
191194
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
192195
SearchResponse searchResponse = client().search(searchRequest).actionGet();
193196
assertEquals(2, searchResponse.getHits().getTotalHits().value);
194197
}
195198
{
196-
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
199+
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
197200
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
198201
searchRequest.indices("<test-{now/d}>");
199202
SearchResponse searchResponse = client().search(searchRequest).actionGet();
200203
assertEquals(1, searchResponse.getHits().getTotalHits().value);
201204
assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex());
202205
}
203206
{
204-
SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(),
207+
SearchRequest searchRequest = SearchRequest.subSearchRequest(parentTaskId, new SearchRequest(),
205208
Strings.EMPTY_ARRAY, "", 0, randomBoolean());
206209
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
207210
RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date");
@@ -217,6 +220,7 @@ public void testAbsoluteStartMillis() {
217220

218221
public void testFinalReduce() {
219222
long nowInMillis = randomLongBetween(0, Long.MAX_VALUE);
223+
TaskId taskId = new TaskId("node", randomNonNegativeLong());
220224
{
221225
IndexRequest indexRequest = new IndexRequest("test");
222226
indexRequest.id("1");
@@ -243,7 +247,7 @@ public void testFinalReduce() {
243247
source.aggregation(terms);
244248

245249
{
246-
SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest,
250+
SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(taskId, originalRequest,
247251
Strings.EMPTY_ARRAY, "remote", nowInMillis, true);
248252
SearchResponse searchResponse = client().search(searchRequest).actionGet();
249253
assertEquals(2, searchResponse.getHits().getTotalHits().value);
@@ -252,7 +256,7 @@ public void testFinalReduce() {
252256
assertEquals(1, longTerms.getBuckets().size());
253257
}
254258
{
255-
SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest,
259+
SearchRequest searchRequest = SearchRequest.subSearchRequest(taskId, originalRequest,
256260
Strings.EMPTY_ARRAY, "remote", nowInMillis, false);
257261
SearchResponse searchResponse = client().search(searchRequest).actionGet();
258262
assertEquals(2, searchResponse.getHits().getTotalHits().value);

server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
package org.elasticsearch.search.ccs;
2121

22+
import org.elasticsearch.action.ActionFuture;
23+
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
24+
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
25+
import org.elasticsearch.action.search.SearchAction;
2226
import org.elasticsearch.action.search.SearchRequest;
2327
import org.elasticsearch.action.search.SearchResponse;
2428
import org.elasticsearch.action.support.PlainActionFuture;
@@ -27,6 +31,7 @@
2731
import org.elasticsearch.cluster.node.DiscoveryNode;
2832
import org.elasticsearch.cluster.node.DiscoveryNodeRole;
2933
import org.elasticsearch.common.settings.Settings;
34+
import org.elasticsearch.common.unit.TimeValue;
3035
import org.elasticsearch.common.util.CollectionUtils;
3136
import org.elasticsearch.index.IndexModule;
3237
import org.elasticsearch.index.query.MatchAllQueryBuilder;
@@ -36,11 +41,13 @@
3641
import org.elasticsearch.search.builder.SearchSourceBuilder;
3742
import org.elasticsearch.search.internal.SearchContext;
3843
import org.elasticsearch.tasks.CancellableTask;
44+
import org.elasticsearch.tasks.TaskInfo;
3945
import org.elasticsearch.test.AbstractMultiClustersTestCase;
4046
import org.elasticsearch.test.InternalTestCluster;
4147
import org.elasticsearch.test.NodeRoles;
4248
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
4349
import org.elasticsearch.transport.TransportService;
50+
import org.hamcrest.Matchers;
4451
import org.junit.Before;
4552

4653
import java.util.Collection;
@@ -146,6 +153,70 @@ public void testProxyConnectionDisconnect() throws Exception {
146153
}
147154
}
148155

156+
public void testCancel() throws Exception {
157+
assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo"));
158+
indexDocs(client(LOCAL_CLUSTER), "demo");
159+
final InternalTestCluster remoteCluster = cluster("cluster_a");
160+
remoteCluster.ensureAtLeastNumDataNodes(1);
161+
final Settings.Builder allocationFilter = Settings.builder();
162+
if (randomBoolean()) {
163+
remoteCluster.ensureAtLeastNumDataNodes(3);
164+
List<String> remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false)
165+
.filter(DiscoveryNode::isDataNode)
166+
.map(DiscoveryNode::getName)
167+
.collect(Collectors.toList());
168+
assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(3));
169+
List<String> seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes);
170+
disconnectFromRemoteClusters();
171+
configureRemoteCluster("cluster_a", seedNodes);
172+
if (randomBoolean()) {
173+
// Using proxy connections
174+
allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes));
175+
} else {
176+
allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes));
177+
}
178+
}
179+
assertAcked(client("cluster_a").admin().indices().prepareCreate("prod")
180+
.setSettings(Settings.builder().put(allocationFilter.build()).put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)));
181+
assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod")
182+
.setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut());
183+
indexDocs(client("cluster_a"), "prod");
184+
SearchListenerPlugin.blockQueryPhase();
185+
PlainActionFuture<SearchResponse> queryFuture = new PlainActionFuture<>();
186+
SearchRequest searchRequest = new SearchRequest("demo", "cluster_a:prod");
187+
searchRequest.allowPartialSearchResults(false);
188+
searchRequest.setCcsMinimizeRoundtrips(false);
189+
searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder()).size(1000));
190+
client(LOCAL_CLUSTER).search(searchRequest, queryFuture);
191+
SearchListenerPlugin.waitSearchStarted();
192+
// Get the search task and cancel it
193+
final TaskInfo rootTask = client().admin().cluster().prepareListTasks()
194+
.setActions(SearchAction.INSTANCE.name())
195+
.get().getTasks().stream().filter(t -> t.getParentTaskId().isSet() == false)
196+
.findFirst().get();
197+
final CancelTasksRequest cancelRequest = new CancelTasksRequest().setTaskId(rootTask.getTaskId());
198+
cancelRequest.setWaitForCompletion(randomBoolean());
199+
final ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().cancelTasks(cancelRequest);
200+
assertBusy(() -> {
201+
final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
202+
for (TransportService transportService : transportServices) {
203+
Collection<CancellableTask> cancellableTasks = transportService.getTaskManager().getCancellableTasks().values();
204+
for (CancellableTask cancellableTask : cancellableTasks) {
205+
assertTrue(cancellableTask.getDescription(), cancellableTask.isCancelled());
206+
}
207+
}
208+
});
209+
SearchListenerPlugin.allowQueryPhase();
210+
assertBusy(() -> assertTrue(queryFuture.isDone()));
211+
assertBusy(() -> assertTrue(cancelFuture.isDone()));
212+
assertBusy(() -> {
213+
final Iterable<TransportService> transportServices = cluster("cluster_a").getInstances(TransportService.class);
214+
for (TransportService transportService : transportServices) {
215+
assertThat(transportService.getTaskManager().getBannedTaskIds(), Matchers.empty());
216+
}
217+
});
218+
}
219+
149220
@Override
150221
protected Collection<Class<? extends Plugin>> nodePlugins(String clusterAlias) {
151222
if (clusterAlias.equals(LOCAL_CLUSTER)) {

server/src/main/java/org/elasticsearch/action/search/SearchRequest.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,21 +140,25 @@ public SearchRequest(String[] indices, SearchSourceBuilder source) {
140140
* Used when a {@link SearchRequest} is created and executed as part of a cross-cluster search request
141141
* performing reduction on each cluster in order to minimize network round-trips between the coordinating node and the remote clusters.
142142
*
143+
* @param parentTaskId the parent taskId of the original search request
143144
* @param originalSearchRequest the original search request
144145
* @param indices the indices to search against
145146
* @param clusterAlias the alias to prefix index names with in the returned search results
146147
* @param absoluteStartMillis the absolute start time to be used on the remote clusters to ensure that the same value is used
147148
* @param finalReduce whether the reduction should be final or not
148149
*/
149-
static SearchRequest subSearchRequest(SearchRequest originalSearchRequest, String[] indices,
150+
static SearchRequest subSearchRequest(TaskId parentTaskId, SearchRequest originalSearchRequest, String[] indices,
150151
String clusterAlias, long absoluteStartMillis, boolean finalReduce) {
152+
Objects.requireNonNull(parentTaskId, "parentTaskId must be specified");
151153
Objects.requireNonNull(originalSearchRequest, "search request must not be null");
152154
validateIndices(indices);
153155
Objects.requireNonNull(clusterAlias, "cluster alias must not be null");
154156
if (absoluteStartMillis < 0) {
155157
throw new IllegalArgumentException("absoluteStartMillis must not be negative but was [" + absoluteStartMillis + "]");
156158
}
157-
return new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce);
159+
final SearchRequest request = new SearchRequest(originalSearchRequest, indices, clusterAlias, absoluteStartMillis, finalReduce);
160+
request.setParentTask(parentTaskId);
161+
return request;
158162
}
159163

160164
private SearchRequest(SearchRequest searchRequest, String[] indices, String localClusterAlias, long absoluteStartMillis,
@@ -304,7 +308,7 @@ boolean isFinalReduce() {
304308
/**
305309
* Returns the current time in milliseconds from the time epoch, to be used for the execution of this search request. Used to
306310
* ensure that the same value, determined by the coordinating node, is used on all nodes involved in the execution of the search
307-
* request. When created through {@link #subSearchRequest(SearchRequest, String[], String, long, boolean)}, this method returns
311+
* request. When created through {@link #subSearchRequest(TaskId, SearchRequest, String[], String, long, boolean)}, this method returns
308312
* the provided current time, otherwise it will return {@link System#currentTimeMillis()}.
309313
*/
310314
long getOrCreateAbsoluteStartMillis() {

server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.elasticsearch.search.profile.ProfileShardResult;
6767
import org.elasticsearch.search.profile.SearchProfileShardResults;
6868
import org.elasticsearch.tasks.Task;
69+
import org.elasticsearch.tasks.TaskId;
6970
import org.elasticsearch.threadpool.ThreadPool;
7071
import org.elasticsearch.transport.RemoteClusterAware;
7172
import org.elasticsearch.transport.RemoteClusterService;
@@ -295,7 +296,8 @@ private void executeRequest(Task task, SearchRequest searchRequest,
295296
task, timeProvider, searchRequest, localIndices, clusterState, listener, searchContext, searchAsyncActionProvider);
296297
} else {
297298
if (shouldMinimizeRoundtrips(searchRequest)) {
298-
ccsRemoteReduce(searchRequest, localIndices, remoteClusterIndices, timeProvider,
299+
final TaskId parentTaskId = task.taskInfo(clusterService.localNode().getId(), false).getTaskId();
300+
ccsRemoteReduce(parentTaskId, searchRequest, localIndices, remoteClusterIndices, timeProvider,
299301
searchService.aggReduceContextBuilder(searchRequest),
300302
remoteClusterService, threadPool, listener,
301303
(r, l) -> executeLocalSearch(
@@ -357,8 +359,9 @@ static boolean shouldMinimizeRoundtrips(SearchRequest searchRequest) {
357359
source.collapse().getInnerHits().isEmpty();
358360
}
359361

360-
static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIndices, Map<String, OriginalIndices> remoteIndices,
361-
SearchTimeProvider timeProvider, InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
362+
static void ccsRemoteReduce(TaskId parentTaskId, SearchRequest searchRequest, OriginalIndices localIndices,
363+
Map<String, OriginalIndices> remoteIndices, SearchTimeProvider timeProvider,
364+
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
362365
RemoteClusterService remoteClusterService, ThreadPool threadPool, ActionListener<SearchResponse> listener,
363366
BiConsumer<SearchRequest, ActionListener<SearchResponse>> localSearchConsumer) {
364367

@@ -369,7 +372,7 @@ static void ccsRemoteReduce(SearchRequest searchRequest, OriginalIndices localIn
369372
String clusterAlias = entry.getKey();
370373
boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
371374
OriginalIndices indices = entry.getValue();
372-
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
375+
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(),
373376
clusterAlias, timeProvider.getAbsoluteStartMillis(), true);
374377
Client remoteClusterClient = remoteClusterService.getRemoteClusterClient(threadPool, clusterAlias);
375378
remoteClusterClient.search(ccsSearchRequest, new ActionListener<SearchResponse>() {
@@ -407,7 +410,7 @@ public void onFailure(Exception e) {
407410
String clusterAlias = entry.getKey();
408411
boolean skipUnavailable = remoteClusterService.isSkipUnavailable(clusterAlias);
409412
OriginalIndices indices = entry.getValue();
410-
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(searchRequest, indices.indices(),
413+
SearchRequest ccsSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, indices.indices(),
411414
clusterAlias, timeProvider.getAbsoluteStartMillis(), false);
412415
ActionListener<SearchResponse> ccsListener = createCCSListener(clusterAlias, skipUnavailable, countDown,
413416
skippedClusters, exceptions, searchResponseMerger, totalClusters, listener);
@@ -417,7 +420,7 @@ public void onFailure(Exception e) {
417420
if (localIndices != null) {
418421
ActionListener<SearchResponse> ccsListener = createCCSListener(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
419422
false, countDown, skippedClusters, exceptions, searchResponseMerger, totalClusters, listener);
420-
SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(searchRequest, localIndices.indices(),
423+
SearchRequest ccsLocalSearchRequest = SearchRequest.subSearchRequest(parentTaskId, searchRequest, localIndices.indices(),
421424
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, timeProvider.getAbsoluteStartMillis(), false);
422425
localSearchConsumer.accept(ccsLocalSearchRequest, ccsListener);
423426
}

server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ private void setBanOnChildConnections(String reason, boolean waitForCompletion,
145145
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size());
146146
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
147147
for (Transport.Connection connection : childConnections) {
148+
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
148149
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, banRequest, TransportRequestOptions.EMPTY,
149150
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
150151
@Override
@@ -167,6 +168,7 @@ private void removeBanOnChildConnections(CancellableTask task, Collection<Transp
167168
final BanParentTaskRequest request =
168169
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
169170
for (Transport.Connection connection : childConnections) {
171+
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
170172
logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
171173
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, request, TransportRequestOptions.EMPTY,
172174
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {

0 commit comments

Comments
 (0)