Skip to content

Commit 7e9153b

Browse files
authored
Expose the logic to cancel task when the rest channel is closed (#51423)
This commit moves the logic that cancels search requests when the rest channel is closed to a generic client that can be used by other APIs. This will be useful for any rest action that wants to cancel the execution of a task if the underlying rest channel is closed by the client before completion. Relates #49931 Relates #50990 Relates #50990
1 parent 9d2c579 commit 7e9153b

File tree

4 files changed

+105
-77
lines changed

4 files changed

+105
-77
lines changed

server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java renamed to server/src/main/java/org/elasticsearch/rest/action/RestCancellableNodeClient.java

Lines changed: 75 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -17,54 +17,84 @@
1717
* under the License.
1818
*/
1919

20-
package org.elasticsearch.rest.action.search;
20+
package org.elasticsearch.rest.action;
2121

2222
import org.elasticsearch.action.ActionListener;
2323
import org.elasticsearch.action.ActionRequest;
2424
import org.elasticsearch.action.ActionResponse;
2525
import org.elasticsearch.action.ActionType;
2626
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
27-
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
28-
import org.elasticsearch.action.support.ContextPreservingActionListener;
2927
import org.elasticsearch.client.Client;
28+
import org.elasticsearch.client.FilterClient;
29+
import org.elasticsearch.client.OriginSettingClient;
3030
import org.elasticsearch.client.node.NodeClient;
31-
import org.elasticsearch.common.util.concurrent.ThreadContext;
3231
import org.elasticsearch.http.HttpChannel;
3332
import org.elasticsearch.tasks.Task;
3433
import org.elasticsearch.tasks.TaskId;
3534

35+
import java.util.ArrayList;
3636
import java.util.HashSet;
37+
import java.util.List;
3738
import java.util.Map;
3839
import java.util.Set;
3940
import java.util.concurrent.ConcurrentHashMap;
4041
import java.util.concurrent.atomic.AtomicReference;
4142

43+
import static org.elasticsearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
44+
4245
/**
43-
* This class executes a request and associates the corresponding {@link Task} with the {@link HttpChannel} that it was originated from,
44-
* so that the tasks associated with a certain channel get cancelled when the underlying connection gets closed.
46+
* A {@linkplain Client} that cancels tasks executed locally when the provided {@link HttpChannel}
47+
* is closed before completion.
4548
*/
46-
public final class HttpChannelTaskHandler {
49+
public class RestCancellableNodeClient extends FilterClient {
50+
private static final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
51+
52+
private final NodeClient client;
53+
private final HttpChannel httpChannel;
4754

48-
public static final HttpChannelTaskHandler INSTANCE = new HttpChannelTaskHandler();
49-
//package private for testing
50-
final Map<HttpChannel, CloseListener> httpChannels = new ConcurrentHashMap<>();
55+
public RestCancellableNodeClient(NodeClient client, HttpChannel httpChannel) {
56+
super(client);
57+
this.client = client;
58+
this.httpChannel = httpChannel;
59+
}
60+
61+
/**
62+
* Returns the number of channels tracked globally.
63+
*/
64+
public static int getNumChannels() {
65+
return httpChannels.size();
66+
}
5167

52-
private HttpChannelTaskHandler() {
68+
/**
69+
* Returns the number of tasks tracked globally.
70+
*/
71+
static int getNumTasks() {
72+
return httpChannels.values().stream()
73+
.mapToInt(CloseListener::getNumTasks)
74+
.sum();
5375
}
5476

55-
<Response extends ActionResponse> void execute(NodeClient client, HttpChannel httpChannel, ActionRequest request,
56-
ActionType<Response> actionType, ActionListener<Response> listener) {
77+
/**
78+
* Returns the number of tasks tracked by the provided {@link HttpChannel}.
79+
*/
80+
static int getNumTasks(HttpChannel channel) {
81+
CloseListener listener = httpChannels.get(channel);
82+
return listener == null ? 0 : listener.getNumTasks();
83+
}
5784

58-
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener(client));
85+
@Override
86+
public <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
87+
ActionType<Response> action, Request request, ActionListener<Response> listener) {
88+
CloseListener closeListener = httpChannels.computeIfAbsent(httpChannel, channel -> new CloseListener());
5989
TaskHolder taskHolder = new TaskHolder();
60-
Task task = client.executeLocally(actionType, request,
90+
Task task = client.executeLocally(action, request,
6191
new ActionListener<>() {
6292
@Override
63-
public void onResponse(Response searchResponse) {
93+
public void onResponse(Response response) {
6494
try {
6595
closeListener.unregisterTask(taskHolder);
6696
} finally {
67-
listener.onResponse(searchResponse);
97+
listener.onResponse(response);
6898
}
6999
}
70100

@@ -77,32 +107,35 @@ public void onFailure(Exception e) {
77107
}
78108
}
79109
});
80-
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
110+
final TaskId taskId = new TaskId(client.getLocalNodeId(), task.getId());
111+
closeListener.registerTask(taskHolder, taskId);
81112
closeListener.maybeRegisterChannel(httpChannel);
82113
}
83114

84-
public int getNumChannels() {
85-
return httpChannels.size();
115+
private void cancelTask(TaskId taskId) {
116+
CancelTasksRequest req = new CancelTasksRequest()
117+
.setTaskId(taskId)
118+
.setReason("channel closed");
119+
// force the origin to execute the cancellation as a system user
120+
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster().cancelTasks(req, ActionListener.wrap(() -> {}));
86121
}
87122

88-
final class CloseListener implements ActionListener<Void> {
89-
private final Client client;
123+
private class CloseListener implements ActionListener<Void> {
90124
private final AtomicReference<HttpChannel> channel = new AtomicReference<>();
91-
private final Set<TaskId> taskIds = new HashSet<>();
125+
private final Set<TaskId> tasks = new HashSet<>();
92126

93-
CloseListener(Client client) {
94-
this.client = client;
127+
CloseListener() {
95128
}
96129

97-
int getNumTasks() {
98-
return taskIds.size();
130+
synchronized int getNumTasks() {
131+
return tasks.size();
99132
}
100133

101134
void maybeRegisterChannel(HttpChannel httpChannel) {
102135
if (channel.compareAndSet(null, httpChannel)) {
103136
//In case the channel is already closed when we register the listener, the listener will be immediately executed which will
104137
//remove the channel from the map straight-away. That is why we first create the CloseListener and later we associate it
105-
//with the channel. This guarantees that the close listener is already in the map when the it gets registered to its
138+
//with the channel. This guarantees that the close listener is already in the map when it gets registered to its
106139
//corresponding channel, hence it is always found in the map when it gets invoked if the channel gets closed.
107140
httpChannel.addCloseListener(this);
108141
}
@@ -111,34 +144,31 @@ void maybeRegisterChannel(HttpChannel httpChannel) {
111144
synchronized void registerTask(TaskHolder taskHolder, TaskId taskId) {
112145
taskHolder.taskId = taskId;
113146
if (taskHolder.completed == false) {
114-
this.taskIds.add(taskId);
147+
this.tasks.add(taskId);
115148
}
116149
}
117150

118151
synchronized void unregisterTask(TaskHolder taskHolder) {
119152
if (taskHolder.taskId != null) {
120-
this.taskIds.remove(taskHolder.taskId);
153+
this.tasks.remove(taskHolder.taskId);
121154
}
122155
taskHolder.completed = true;
123156
}
124157

125158
@Override
126-
public synchronized void onResponse(Void aVoid) {
127-
//When the channel gets closed it won't be reused: we can remove it from the map and forget about it.
128-
CloseListener closeListener = httpChannels.remove(channel.get());
159+
public void onResponse(Void aVoid) {
160+
final HttpChannel httpChannel = channel.get();
161+
assert httpChannel != null : "channel not registered";
162+
// when the channel gets closed it won't be reused: we can remove it from the map and forget about it.
163+
CloseListener closeListener = httpChannels.remove(httpChannel);
129164
assert closeListener != null : "channel not found in the map of tracked channels";
130-
for (TaskId taskId : taskIds) {
131-
ThreadContext threadContext = client.threadPool().getThreadContext();
132-
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
133-
// we stash any context here since this is an internal execution and should not leak any existing context information
134-
threadContext.markAsSystemContext();
135-
ContextPreservingActionListener<CancelTasksResponse> contextPreservingListener = new ContextPreservingActionListener<>(
136-
threadContext.newRestorableContext(false), ActionListener.wrap(r -> {}, e -> {}));
137-
CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
138-
cancelTasksRequest.setTaskId(taskId);
139-
//We don't wait for cancel tasks to come back. Task cancellation is just best effort.
140-
client.admin().cluster().cancelTasks(cancelTasksRequest, contextPreservingListener);
141-
}
165+
final List<TaskId> toCancel;
166+
synchronized (this) {
167+
toCancel = new ArrayList<>(tasks);
168+
tasks.clear();
169+
}
170+
for (TaskId taskId : toCancel) {
171+
cancelTask(taskId);
142172
}
143173
}
144174

server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
import org.elasticsearch.action.search.SearchAction;
2323
import org.elasticsearch.action.search.SearchRequest;
24-
import org.elasticsearch.action.search.SearchResponse;
2524
import org.elasticsearch.action.support.IndicesOptions;
2625
import org.elasticsearch.client.node.NodeClient;
2726
import org.elasticsearch.common.Booleans;
@@ -32,6 +31,7 @@
3231
import org.elasticsearch.rest.RestController;
3332
import org.elasticsearch.rest.RestRequest;
3433
import org.elasticsearch.rest.action.RestActions;
34+
import org.elasticsearch.rest.action.RestCancellableNodeClient;
3535
import org.elasticsearch.rest.action.RestStatusToXContentListener;
3636
import org.elasticsearch.search.Scroll;
3737
import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -100,8 +100,8 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
100100
parseSearchRequest(searchRequest, request, parser, setSize));
101101

102102
return channel -> {
103-
RestStatusToXContentListener<SearchResponse> listener = new RestStatusToXContentListener<>(channel);
104-
HttpChannelTaskHandler.INSTANCE.execute(client, request.getHttpChannel(), searchRequest, SearchAction.INSTANCE, listener);
103+
RestCancellableNodeClient cancelClient = new RestCancellableNodeClient(client, request.getHttpChannel());
104+
cancelClient.execute(SearchAction.INSTANCE, searchRequest, new RestStatusToXContentListener<>(channel));
105105
};
106106
}
107107

server/src/test/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandlerTests.java renamed to server/src/test/java/org/elasticsearch/rest/action/RestCancellableNodeClientTests.java

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
* under the License.
1818
*/
1919

20-
package org.elasticsearch.rest.action.search;
20+
package org.elasticsearch.rest.action;
2121

2222
import org.elasticsearch.action.ActionListener;
2323
import org.elasticsearch.action.ActionRequest;
@@ -45,7 +45,6 @@
4545
import java.util.ArrayList;
4646
import java.util.Collections;
4747
import java.util.List;
48-
import java.util.Map;
4948
import java.util.Set;
5049
import java.util.concurrent.CopyOnWriteArraySet;
5150
import java.util.concurrent.CountDownLatch;
@@ -56,13 +55,13 @@
5655
import java.util.concurrent.atomic.AtomicLong;
5756
import java.util.concurrent.atomic.AtomicReference;
5857

59-
public class HttpChannelTaskHandlerTests extends ESTestCase {
58+
public class RestCancellableNodeClientTests extends ESTestCase {
6059

6160
private ThreadPool threadPool;
6261

6362
@Before
6463
public void createThreadPool() {
65-
threadPool = new TestThreadPool(HttpChannelTaskHandlerTests.class.getName());
64+
threadPool = new TestThreadPool(RestCancellableNodeClientTests.class.getName());
6665
}
6766

6867
@After
@@ -77,8 +76,7 @@ public void stopThreadPool() {
7776
*/
7877
public void testCompletedTasks() throws Exception {
7978
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, false)) {
80-
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
81-
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
79+
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
8280
int totalSearches = 0;
8381
List<Future<?>> futures = new ArrayList<>();
8482
int numChannels = randomIntBetween(1, 30);
@@ -88,19 +86,17 @@ public void testCompletedTasks() throws Exception {
8886
totalSearches += numTasks;
8987
for (int j = 0; j < numTasks; j++) {
9088
PlainListenableActionFuture<SearchResponse> actionFuture = PlainListenableActionFuture.newListenableFuture();
91-
threadPool.generic().submit(() -> httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(),
92-
SearchAction.INSTANCE, actionFuture));
89+
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
90+
threadPool.generic().submit(() -> client.execute(SearchAction.INSTANCE, new SearchRequest(), actionFuture));
9391
futures.add(actionFuture);
9492
}
9593
}
9694
for (Future<?> future : futures) {
9795
future.get();
9896
}
9997
//no channels get closed in this test, hence we expect as many channels as we created in the map
100-
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
101-
for (Map.Entry<HttpChannel, HttpChannelTaskHandler.CloseListener> entry : httpChannelTaskHandler.httpChannels.entrySet()) {
102-
assertEquals(0, entry.getValue().getNumTasks());
103-
}
98+
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
99+
assertEquals(0, RestCancellableNodeClient.getNumTasks());
104100
assertEquals(totalSearches, testClient.searchRequests.get());
105101
}
106102
}
@@ -110,9 +106,8 @@ public void testCompletedTasks() throws Exception {
110106
* removed and all of its corresponding tasks get cancelled.
111107
*/
112108
public void testCancelledTasks() throws Exception {
113-
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
114-
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
115-
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
109+
try (TestClient nodeClient = new TestClient(Settings.EMPTY, threadPool, true)) {
110+
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
116111
int numChannels = randomIntBetween(1, 30);
117112
int totalSearches = 0;
118113
List<TestHttpChannel> channels = new ArrayList<>(numChannels);
@@ -121,18 +116,19 @@ public void testCancelledTasks() throws Exception {
121116
channels.add(channel);
122117
int numTasks = randomIntBetween(1, 30);
123118
totalSearches += numTasks;
119+
RestCancellableNodeClient client = new RestCancellableNodeClient(nodeClient, channel);
124120
for (int j = 0; j < numTasks; j++) {
125-
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
121+
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
126122
}
127-
assertEquals(numTasks, httpChannelTaskHandler.httpChannels.get(channel).getNumTasks());
123+
assertEquals(numTasks, RestCancellableNodeClient.getNumTasks(channel));
128124
}
129-
assertEquals(initialHttpChannels + numChannels, httpChannelTaskHandler.getNumChannels());
125+
assertEquals(initialHttpChannels + numChannels, RestCancellableNodeClient.getNumChannels());
130126
for (TestHttpChannel channel : channels) {
131127
channel.awaitClose();
132128
}
133-
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
134-
assertEquals(totalSearches, testClient.searchRequests.get());
135-
assertEquals(totalSearches, testClient.cancelledTasks.size());
129+
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
130+
assertEquals(totalSearches, nodeClient.searchRequests.get());
131+
assertEquals(totalSearches, nodeClient.cancelledTasks.size());
136132
}
137133
}
138134

@@ -144,8 +140,7 @@ public void testCancelledTasks() throws Exception {
144140
*/
145141
public void testChannelAlreadyClosed() {
146142
try (TestClient testClient = new TestClient(Settings.EMPTY, threadPool, true)) {
147-
HttpChannelTaskHandler httpChannelTaskHandler = HttpChannelTaskHandler.INSTANCE;
148-
int initialHttpChannels = httpChannelTaskHandler.getNumChannels();
143+
int initialHttpChannels = RestCancellableNodeClient.getNumChannels();
149144
int numChannels = randomIntBetween(1, 30);
150145
int totalSearches = 0;
151146
for (int i = 0; i < numChannels; i++) {
@@ -154,12 +149,13 @@ public void testChannelAlreadyClosed() {
154149
channel.close();
155150
int numTasks = randomIntBetween(1, 5);
156151
totalSearches += numTasks;
152+
RestCancellableNodeClient client = new RestCancellableNodeClient(testClient, channel);
157153
for (int j = 0; j < numTasks; j++) {
158154
//here the channel will be first registered, then straight-away removed from the map as the close listener is invoked
159-
httpChannelTaskHandler.execute(testClient, channel, new SearchRequest(), SearchAction.INSTANCE, null);
155+
client.execute(SearchAction.INSTANCE, new SearchRequest(), null);
160156
}
161157
}
162-
assertEquals(initialHttpChannels, httpChannelTaskHandler.getNumChannels());
158+
assertEquals(initialHttpChannels, RestCancellableNodeClient.getNumChannels());
163159
assertEquals(totalSearches, testClient.searchRequests.get());
164160
assertEquals(totalSearches, testClient.cancelledTasks.size());
165161
}

test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@
113113
import org.elasticsearch.plugins.NetworkPlugin;
114114
import org.elasticsearch.plugins.Plugin;
115115
import org.elasticsearch.rest.RestStatus;
116-
import org.elasticsearch.rest.action.search.HttpChannelTaskHandler;
116+
import org.elasticsearch.rest.action.RestCancellableNodeClient;
117117
import org.elasticsearch.script.ScriptService;
118118
import org.elasticsearch.search.MockSearchService;
119119
import org.elasticsearch.search.SearchHit;
@@ -511,9 +511,11 @@ private static void clearClusters() throws Exception {
511511
restClient.close();
512512
restClient = null;
513513
}
514-
assertBusy(() -> assertEquals(HttpChannelTaskHandler.INSTANCE.getNumChannels() + " channels still being tracked in " +
515-
HttpChannelTaskHandler.class.getSimpleName() + " while there should be none", 0,
516-
HttpChannelTaskHandler.INSTANCE.getNumChannels()));
514+
assertBusy(() -> {
515+
int numChannels = RestCancellableNodeClient.getNumChannels();
516+
assertEquals( numChannels+ " channels still being tracked in " + RestCancellableNodeClient.class.getSimpleName()
517+
+ " while there should be none", 0, numChannels);
518+
});
517519
}
518520

519521
private void afterInternal(boolean afterClass) throws Exception {

0 commit comments

Comments
 (0)