Skip to content

Commit 2257f8c

Browse files
authored
Add support for task cancellation to TransportNodesAction (elastic#71895)
Backport of elastic#71695
1 parent 4d04199 commit 2257f8c

File tree

5 files changed: 65 additions (+65) and 15 deletions (-15)

server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
import org.elasticsearch.plugins.Plugin;
4444
import org.elasticsearch.search.builder.SearchSourceBuilder;
4545
import org.elasticsearch.tasks.Task;
46+
import org.elasticsearch.tasks.TaskCancelledException;
4647
import org.elasticsearch.tasks.TaskId;
4748
import org.elasticsearch.tasks.TaskInfo;
4849
import org.elasticsearch.tasks.TaskResult;
@@ -496,7 +497,7 @@ public void testTasksCancellation() throws Exception {
496497
.setActions(TestTaskPlugin.TestTaskAction.NAME).get();
497498
assertEquals(1, cancelTasksResponse.getTasks().size());
498499

499-
future.get();
500+
expectThrows(TaskCancelledException.class, future::actionGet);
500501

501502
logger.info("--> checking that test tasks are not running");
502503
assertEquals(0,

server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java

Lines changed: 18 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,9 @@
1818
import org.elasticsearch.cluster.service.ClusterService;
1919
import org.elasticsearch.common.io.stream.StreamInput;
2020
import org.elasticsearch.common.io.stream.Writeable;
21+
import org.elasticsearch.tasks.CancellableTask;
2122
import org.elasticsearch.tasks.Task;
23+
import org.elasticsearch.tasks.TaskCancelledException;
2224
import org.elasticsearch.threadpool.ThreadPool;
2325
import org.elasticsearch.transport.TransportChannel;
2426
import org.elasticsearch.transport.TransportException;
@@ -203,9 +205,7 @@ class AsyncAction {
203205
void start() {
204206
final DiscoveryNode[] nodes = request.concreteNodes();
205207
if (nodes.length == 0) {
206-
// nothing to notify, so respond immediately, but always fork even if finalExecutor == SAME
207-
final String executor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor;
208-
threadPool.executor(executor).execute(() -> newResponse(task, request, responses, listener));
208+
finishHim();
209209
return;
210210
}
211211
final TransportRequestOptions transportRequestOptions = TransportRequestOptions.timeout(request.timeout());
@@ -258,14 +258,27 @@ private void onFailure(int idx, String nodeId, Throwable t) {
258258
}
259259

260260
private void finishHim() {
261-
threadPool.executor(finalExecutor).execute(() -> newResponse(task, request, responses, listener));
261+
if (isCancelled(task)) {
262+
listener.onFailure(new TaskCancelledException("task cancelled"));
263+
return;
264+
}
265+
266+
final String executor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor;
267+
threadPool.executor(executor).execute(() -> newResponse(task, request, responses, listener));
262268
}
263269
}
264270

265-
class NodeTransportHandler implements TransportRequestHandler<NodeRequest> {
271+
private boolean isCancelled(Task task) {
272+
return task instanceof CancellableTask && ((CancellableTask) task).isCancelled();
273+
}
266274

275+
class NodeTransportHandler implements TransportRequestHandler<NodeRequest> {
267276
@Override
268277
public void messageReceived(NodeRequest request, TransportChannel channel, Task task) throws Exception {
278+
if (isCancelled(task)) {
279+
throw new TaskCancelledException("task cancelled");
280+
}
281+
269282
channel.sendResponse(nodeOperation(request, task));
270283
}
271284
}

server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@
5454
import static org.elasticsearch.test.ClusterServiceUtils.setState;
5555
import static org.hamcrest.Matchers.equalTo;
5656
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
57+
import static org.hamcrest.Matchers.instanceOf;
5758
import static org.hamcrest.Matchers.lessThanOrEqualTo;
5859
import static org.hamcrest.Matchers.startsWith;
5960

@@ -224,7 +225,9 @@ public void testBasicTaskCancellation() throws Exception {
224225
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
225226
final AtomicReference<Throwable> throwableReference = new AtomicReference<>();
226227
int runNodesCount = randomIntBetween(1, nodesCount);
227-
int blockedNodesCount = randomIntBetween(0, runNodesCount);
228+
// Block at least 1 node, otherwise it's quite easy to end up in a race condition where the node tasks
229+
// have finished before the cancel request has arrived
230+
int blockedNodesCount = randomIntBetween(1, runNodesCount);
228231
Task mainTask = startCancellableTestNodesAction(waitForActionToStart, runNodesCount, blockedNodesCount,
229232
new ActionListener<NodesResponse>() {
230233
@Override
@@ -261,12 +264,7 @@ public void onFailure(Exception e) {
261264
assertEquals(runNodesCount, responseReference.get().getNodes().size());
262265
assertEquals(0, responseReference.get().failureCount());
263266
} else {
264-
// We canceled the request, in this case it should have fail, but we should get partial response
265-
assertNull(throwableReference.get());
266-
assertEquals(runNodesCount, responseReference.get().failureCount() + responseReference.get().getNodes().size());
267-
// and we should have at least as many failures as the number of blocked operations
268-
// (we might have cancelled some non-blocked operations before they even started and that's ok)
269-
assertThat(responseReference.get().failureCount(), greaterThanOrEqualTo(blockedNodesCount));
267+
assertThat(throwableReference.get(), instanceOf(TaskCancelledException.class));
270268

271269
// We should have the information about the cancelled task in the cancel operation response
272270
assertEquals(1, response.getTasks().size());

server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/TaskManagerTestCase.java

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -154,7 +154,6 @@ protected NodeResponse newNodeResponse(StreamInput in) throws IOException {
154154

155155
@Override
156156
protected abstract NodeResponse nodeOperation(NodeRequest request);
157-
158157
}
159158

160159
public static class TestNode implements Releasable {

server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java

Lines changed: 40 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,9 @@
2222
import org.elasticsearch.common.io.stream.StreamInput;
2323
import org.elasticsearch.common.io.stream.StreamOutput;
2424
import org.elasticsearch.common.io.stream.Writeable;
25+
import org.elasticsearch.tasks.CancellableTask;
2526
import org.elasticsearch.tasks.Task;
27+
import org.elasticsearch.tasks.TaskCancelledException;
2628
import org.elasticsearch.test.ESTestCase;
2729
import org.elasticsearch.test.transport.CapturingTransport;
2830
import org.elasticsearch.threadpool.TestThreadPool;
@@ -41,6 +43,7 @@
4143
import java.util.List;
4244
import java.util.Map;
4345
import java.util.Set;
46+
import java.util.concurrent.ExecutionException;
4447
import java.util.concurrent.TimeUnit;
4548
import java.util.concurrent.atomic.AtomicReferenceArray;
4649
import java.util.function.Supplier;
@@ -138,6 +141,37 @@ public void testCustomResolving() throws Exception {
138141
assertEquals(clusterService.state().nodes().getDataNodes().size(), capturedRequests.size());
139142
}
140143

144+
public void testTaskCancellationThrowsException() {
145+
TransportNodesAction<TestNodesRequest, TestNodesResponse, TestNodeRequest, TestNodeResponse> action = getTestTransportNodesAction();
146+
List<String> nodeIds = new ArrayList<>();
147+
for (DiscoveryNode node : clusterService.state().nodes()) {
148+
nodeIds.add(node.getId());
149+
}
150+
151+
TestNodesRequest request = new TestNodesRequest(nodeIds.toArray(new String[0]));
152+
PlainActionFuture<TestNodesResponse> listener = new PlainActionFuture<>();
153+
Task cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap()) {
154+
@Override
155+
public boolean isCancelled() {
156+
return true;
157+
}
158+
};
159+
action.doExecute(cancellableTask, request, listener);
160+
Map<String, List<CapturingTransport.CapturedRequest>> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
161+
for (List<CapturingTransport.CapturedRequest> requests : capturedRequests.values()) {
162+
for (CapturingTransport.CapturedRequest capturedRequest : requests) {
163+
if (randomBoolean()) {
164+
transport.handleResponse(capturedRequest.requestId, new TestNodeResponse(capturedRequest.node));
165+
} else {
166+
transport.handleRemoteError(capturedRequest.requestId, new TaskCancelledException("simulated"));
167+
}
168+
}
169+
}
170+
171+
assertTrue(listener.isDone());
172+
expectThrows(ExecutionException.class, TaskCancelledException.class, listener::get);
173+
}
174+
141175
private <T> List<T> mockList(Supplier<T> supplier, int size) {
142176
List<T> failures = new ArrayList<>(size);
143177
for (int i = 0; i < size; ++i) {
@@ -321,8 +355,13 @@ private static class TestNodeRequest extends BaseNodeRequest {
321355

322356
private static class TestNodeResponse extends BaseNodeResponse {
323357
TestNodeResponse() {
324-
super(mock(DiscoveryNode.class));
358+
this(mock(DiscoveryNode.class));
325359
}
360+
361+
TestNodeResponse(DiscoveryNode node) {
362+
super(node);
363+
}
364+
326365
protected TestNodeResponse(StreamInput in) throws IOException {
327366
super(in);
328367
}

0 commit comments

Comments
 (0)