Skip to content

Commit d4655e8

Browse files
authored
Discard intermediate node results when a request is cancelled (#82685)
Resolves #82337
1 parent d4caeea commit d4655e8

File tree

9 files changed

+341
-79
lines changed

9 files changed

+341
-79
lines changed

docs/changelog/82685.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 82685
2+
summary: Discard intermediate results upon cancellation for stats endpoints
3+
area: Stats
4+
type: bug
5+
issues:
6+
- 82337
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.action.support;
10+
11+
import java.util.Collection;
12+
import java.util.concurrent.atomic.AtomicInteger;
13+
import java.util.concurrent.atomic.AtomicReferenceArray;
14+
15+
/**
16+
* This class tracks the intermediate responses that will be used to create aggregated cluster response to a request. It also gives the
17+
* possibility to discard the intermediate results when asked, for example when the initial request is cancelled, in order to release the
18+
* resources.
19+
*/
20+
public class NodeResponseTracker {
21+
22+
private final AtomicInteger counter = new AtomicInteger();
23+
private final int expectedResponsesCount;
24+
private volatile AtomicReferenceArray<Object> responses;
25+
private volatile Exception causeOfDiscarding;
26+
27+
public NodeResponseTracker(int size) {
28+
this.expectedResponsesCount = size;
29+
this.responses = new AtomicReferenceArray<>(size);
30+
}
31+
32+
public NodeResponseTracker(Collection<Object> array) {
33+
this.expectedResponsesCount = array.size();
34+
this.responses = new AtomicReferenceArray<>(array.toArray());
35+
}
36+
37+
/**
38+
* This method discards the results collected so far to free up the resources.
39+
* @param cause the discarding, this will be communicated if they try to access the discarded results
40+
*/
41+
public void discardIntermediateResponses(Exception cause) {
42+
if (responses != null) {
43+
this.causeOfDiscarding = cause;
44+
responses = null;
45+
}
46+
}
47+
48+
public boolean responsesDiscarded() {
49+
return responses == null;
50+
}
51+
52+
/**
53+
* This method stores a new node response if the intermediate responses haven't been discarded yet. If the responses are not discarded
54+
* the method asserts that this is the first response encountered from this node to protect from miscounting the responses in case of a
55+
* double invocation. If the responses have been discarded we accept this risk for simplicity.
56+
* @param nodeIndex, the index that represents a single node of the cluster
57+
* @param response, a response can be either a NodeResponse or an error
58+
* @return true if all the nodes' responses have been received, else false
59+
*/
60+
public boolean trackResponseAndCheckIfLast(int nodeIndex, Object response) {
61+
AtomicReferenceArray<Object> responses = this.responses;
62+
63+
if (responsesDiscarded() == false) {
64+
boolean firstEncounter = responses.compareAndSet(nodeIndex, null, response);
65+
assert firstEncounter : "a response should be tracked only once";
66+
}
67+
return counter.incrementAndGet() == getExpectedResponseCount();
68+
}
69+
70+
/**
71+
* Returns the tracked response or null if the response hasn't been received yet for a specific index that represents a node of the
72+
* cluster.
73+
* @throws DiscardedResponsesException if the responses have been discarded
74+
*/
75+
public Object getResponse(int nodeIndex) throws DiscardedResponsesException {
76+
AtomicReferenceArray<Object> responses = this.responses;
77+
if (responsesDiscarded()) {
78+
throw new DiscardedResponsesException(causeOfDiscarding);
79+
}
80+
return responses.get(nodeIndex);
81+
}
82+
83+
public int getExpectedResponseCount() {
84+
return expectedResponsesCount;
85+
}
86+
87+
/**
88+
* This exception is thrown when the {@link NodeResponseTracker} is asked to give information about the responses after they have been
89+
* discarded.
90+
*/
91+
public static class DiscardedResponsesException extends Exception {
92+
93+
public DiscardedResponsesException(Exception cause) {
94+
super(cause);
95+
}
96+
}
97+
}

server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java

+45-31
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
1717
import org.elasticsearch.action.support.HandledTransportAction;
1818
import org.elasticsearch.action.support.IndicesOptions;
19+
import org.elasticsearch.action.support.NodeResponseTracker;
1920
import org.elasticsearch.action.support.TransportActions;
2021
import org.elasticsearch.action.support.broadcast.BroadcastRequest;
2122
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
@@ -51,7 +52,6 @@
5152
import java.util.List;
5253
import java.util.Map;
5354
import java.util.concurrent.atomic.AtomicInteger;
54-
import java.util.concurrent.atomic.AtomicReferenceArray;
5555
import java.util.function.Consumer;
5656

5757
/**
@@ -118,28 +118,29 @@ public TransportBroadcastByNodeAction(
118118

119119
private Response newResponse(
120120
Request request,
121-
AtomicReferenceArray<?> responses,
121+
NodeResponseTracker nodeResponseTracker,
122122
int unavailableShardCount,
123123
Map<String, List<ShardRouting>> nodes,
124124
ClusterState clusterState
125-
) {
125+
) throws NodeResponseTracker.DiscardedResponsesException {
126126
int totalShards = 0;
127127
int successfulShards = 0;
128128
List<ShardOperationResult> broadcastByNodeResponses = new ArrayList<>();
129129
List<DefaultShardOperationFailedException> exceptions = new ArrayList<>();
130-
for (int i = 0; i < responses.length(); i++) {
131-
if (responses.get(i)instanceof FailedNodeException exception) {
130+
for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); i++) {
131+
Object response = nodeResponseTracker.getResponse(i);
132+
if (response instanceof FailedNodeException exception) {
132133
totalShards += nodes.get(exception.nodeId()).size();
133134
for (ShardRouting shard : nodes.get(exception.nodeId())) {
134135
exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), exception));
135136
}
136137
} else {
137138
@SuppressWarnings("unchecked")
138-
NodeResponse response = (NodeResponse) responses.get(i);
139-
broadcastByNodeResponses.addAll(response.results);
140-
totalShards += response.getTotalShards();
141-
successfulShards += response.getSuccessfulShards();
142-
for (BroadcastShardOperationFailedException throwable : response.getExceptions()) {
139+
NodeResponse nodeResponse = (NodeResponse) response;
140+
broadcastByNodeResponses.addAll(nodeResponse.results);
141+
totalShards += nodeResponse.getTotalShards();
142+
successfulShards += nodeResponse.getSuccessfulShards();
143+
for (BroadcastShardOperationFailedException throwable : nodeResponse.getExceptions()) {
143144
if (TransportActions.isShardNotAvailableException(throwable) == false) {
144145
exceptions.add(
145146
new DefaultShardOperationFailedException(
@@ -256,16 +257,15 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
256257
new AsyncAction(task, request, listener).start();
257258
}
258259

259-
protected class AsyncAction {
260+
protected class AsyncAction implements CancellableTask.CancellationListener {
260261
private final Task task;
261262
private final Request request;
262263
private final ActionListener<Response> listener;
263264
private final ClusterState clusterState;
264265
private final DiscoveryNodes nodes;
265266
private final Map<String, List<ShardRouting>> nodeIds;
266-
private final AtomicReferenceArray<Object> responses;
267-
private final AtomicInteger counter = new AtomicInteger();
268267
private final int unavailableShardCount;
268+
private final NodeResponseTracker nodeResponseTracker;
269269

270270
protected AsyncAction(Task task, Request request, ActionListener<Response> listener) {
271271
this.task = task;
@@ -312,10 +312,13 @@ protected AsyncAction(Task task, Request request, ActionListener<Response> liste
312312

313313
}
314314
this.unavailableShardCount = unavailableShardCount;
315-
responses = new AtomicReferenceArray<>(nodeIds.size());
315+
nodeResponseTracker = new NodeResponseTracker(nodeIds.size());
316316
}
317317

318318
public void start() {
319+
if (task instanceof CancellableTask cancellableTask) {
320+
cancellableTask.addListener(this);
321+
}
319322
if (nodeIds.size() == 0) {
320323
try {
321324
onCompletion();
@@ -373,38 +376,34 @@ protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse re
373376
logger.trace("received response for [{}] from node [{}]", actionName, node.getId());
374377
}
375378

376-
// this is defensive to protect against the possibility of double invocation
377-
// the current implementation of TransportService#sendRequest guards against this
378-
// but concurrency is hard, safety is important, and the small performance loss here does not matter
379-
if (responses.compareAndSet(nodeIndex, null, response)) {
380-
if (counter.incrementAndGet() == responses.length()) {
381-
onCompletion();
382-
}
379+
if (nodeResponseTracker.trackResponseAndCheckIfLast(nodeIndex, response)) {
380+
onCompletion();
383381
}
384382
}
385383

386384
protected void onNodeFailure(DiscoveryNode node, int nodeIndex, Throwable t) {
387385
String nodeId = node.getId();
388386
logger.debug(new ParameterizedMessage("failed to execute [{}] on node [{}]", actionName, nodeId), t);
389-
390-
// this is defensive to protect against the possibility of double invocation
391-
// the current implementation of TransportService#sendRequest guards against this
392-
// but concurrency is hard, safety is important, and the small performance loss here does not matter
393-
if (responses.compareAndSet(nodeIndex, null, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
394-
if (counter.incrementAndGet() == responses.length()) {
395-
onCompletion();
396-
}
387+
if (nodeResponseTracker.trackResponseAndCheckIfLast(
388+
nodeIndex,
389+
new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)
390+
)) {
391+
onCompletion();
397392
}
398393
}
399394

400395
protected void onCompletion() {
401-
if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) {
396+
if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
402397
return;
403398
}
404399

405400
Response response = null;
406401
try {
407-
response = newResponse(request, responses, unavailableShardCount, nodeIds, clusterState);
402+
response = newResponse(request, nodeResponseTracker, unavailableShardCount, nodeIds, clusterState);
403+
} catch (NodeResponseTracker.DiscardedResponsesException e) {
404+
// We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take
405+
// follow-up actions
406+
listener.onFailure((Exception) e.getCause());
408407
} catch (Exception e) {
409408
logger.debug("failed to combine responses from nodes", e);
410409
listener.onFailure(e);
@@ -417,6 +416,21 @@ protected void onCompletion() {
417416
}
418417
}
419418
}
419+
420+
@Override
421+
public void onCancelled() {
422+
assert task instanceof CancellableTask : "task must be cancellable";
423+
try {
424+
((CancellableTask) task).ensureNotCancelled();
425+
} catch (TaskCancelledException e) {
426+
nodeResponseTracker.discardIntermediateResponses(e);
427+
}
428+
}
429+
430+
// For testing purposes
431+
public NodeResponseTracker getNodeResponseTracker() {
432+
return nodeResponseTracker;
433+
}
420434
}
421435

422436
class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler<NodeRequest> {

0 commit comments

Comments
 (0)