Skip to content

Commit bdffb05

Browse files
authored
Discard intermediate node results when a request is cancelled (elastic#88851)
* Discard intermediate node results when a request is cancelled * Update docs/changelog/88851.yaml * Refer to issue in the change log * Remove the changelog of the original PR * Fix formatting * Delete docs/changelog/88851.yaml * Revert "Remove the changelog of the original PR" This reverts commit d8c8f07.
1 parent 807a9d8 commit bdffb05

File tree

9 files changed

+342
-66
lines changed

9 files changed

+342
-66
lines changed

docs/changelog/82685.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 82685
2+
summary: Discard intermediate results upon cancellation for stats endpoints
3+
area: Stats
4+
type: bug
5+
issues:
6+
- 82337
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.action.support;
10+
11+
import java.util.Collection;
12+
import java.util.concurrent.atomic.AtomicInteger;
13+
import java.util.concurrent.atomic.AtomicReferenceArray;
14+
15+
/**
16+
* This class tracks the intermediate responses that will be used to create aggregated cluster response to a request. It also gives the
17+
* possibility to discard the intermediate results when asked, for example when the initial request is cancelled, in order to release the
18+
* resources.
19+
*/
20+
public class NodeResponseTracker {
21+
22+
private final AtomicInteger counter = new AtomicInteger();
23+
private final int expectedResponsesCount;
24+
private volatile AtomicReferenceArray<Object> responses;
25+
private volatile Exception causeOfDiscarding;
26+
27+
public NodeResponseTracker(int size) {
28+
this.expectedResponsesCount = size;
29+
this.responses = new AtomicReferenceArray<>(size);
30+
}
31+
32+
public NodeResponseTracker(Collection<Object> array) {
33+
this.expectedResponsesCount = array.size();
34+
this.responses = new AtomicReferenceArray<>(array.toArray());
35+
}
36+
37+
/**
38+
* This method discards the results collected so far to free up the resources.
39+
* @param cause the discarding, this will be communicated if they try to access the discarded results
40+
*/
41+
public void discardIntermediateResponses(Exception cause) {
42+
if (responses != null) {
43+
this.causeOfDiscarding = cause;
44+
responses = null;
45+
}
46+
}
47+
48+
public boolean responsesDiscarded() {
49+
return responses == null;
50+
}
51+
52+
/**
53+
* This method stores a new node response if the intermediate responses haven't been discarded yet. If the responses are not discarded
54+
* the method asserts that this is the first response encountered from this node to protect from miscounting the responses in case of a
55+
* double invocation. If the responses have been discarded we accept this risk for simplicity.
56+
* @param nodeIndex, the index that represents a single node of the cluster
57+
* @param response, a response can be either a NodeResponse or an error
58+
* @return true if all the nodes' responses have been received, else false
59+
*/
60+
public boolean trackResponseAndCheckIfLast(int nodeIndex, Object response) {
61+
AtomicReferenceArray<Object> responses = this.responses;
62+
63+
if (responsesDiscarded() == false) {
64+
boolean firstEncounter = responses.compareAndSet(nodeIndex, null, response);
65+
assert firstEncounter : "a response should be tracked only once";
66+
}
67+
return counter.incrementAndGet() == getExpectedResponseCount();
68+
}
69+
70+
/**
71+
* Returns the tracked response or null if the response hasn't been received yet for a specific index that represents a node of the
72+
* cluster.
73+
* @throws DiscardedResponsesException if the responses have been discarded
74+
*/
75+
public Object getResponse(int nodeIndex) throws DiscardedResponsesException {
76+
AtomicReferenceArray<Object> responses = this.responses;
77+
if (responsesDiscarded()) {
78+
throw new DiscardedResponsesException(causeOfDiscarding);
79+
}
80+
return responses.get(nodeIndex);
81+
}
82+
83+
public int getExpectedResponseCount() {
84+
return expectedResponsesCount;
85+
}
86+
87+
/**
88+
* This exception is thrown when the {@link NodeResponseTracker} is asked to give information about the responses after they have been
89+
* discarded.
90+
*/
91+
public static class DiscardedResponsesException extends Exception {
92+
93+
public DiscardedResponsesException(Exception cause) {
94+
super(cause);
95+
}
96+
}
97+
}

server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java

Lines changed: 51 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
1717
import org.elasticsearch.action.support.HandledTransportAction;
1818
import org.elasticsearch.action.support.IndicesOptions;
19+
import org.elasticsearch.action.support.NodeResponseTracker;
1920
import org.elasticsearch.action.support.TransportActions;
2021
import org.elasticsearch.action.support.broadcast.BroadcastRequest;
2122
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
@@ -51,7 +52,6 @@
5152
import java.util.List;
5253
import java.util.Map;
5354
import java.util.concurrent.atomic.AtomicInteger;
54-
import java.util.concurrent.atomic.AtomicReferenceArray;
5555
import java.util.function.Consumer;
5656

5757
/**
@@ -118,29 +118,30 @@ public TransportBroadcastByNodeAction(
118118

119119
private Response newResponse(
120120
Request request,
121-
AtomicReferenceArray<?> responses,
121+
NodeResponseTracker nodeResponseTracker,
122122
int unavailableShardCount,
123123
Map<String, List<ShardRouting>> nodes,
124124
ClusterState clusterState
125-
) {
125+
) throws NodeResponseTracker.DiscardedResponsesException {
126126
int totalShards = 0;
127127
int successfulShards = 0;
128128
List<ShardOperationResult> broadcastByNodeResponses = new ArrayList<>();
129129
List<DefaultShardOperationFailedException> exceptions = new ArrayList<>();
130-
for (int i = 0; i < responses.length(); i++) {
131-
if (responses.get(i) instanceof FailedNodeException) {
132-
FailedNodeException exception = (FailedNodeException) responses.get(i);
130+
for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); i++) {
131+
Object response = nodeResponseTracker.getResponse(i);
132+
if (response instanceof FailedNodeException) {
133+
FailedNodeException exception = (FailedNodeException) response;
133134
totalShards += nodes.get(exception.nodeId()).size();
134135
for (ShardRouting shard : nodes.get(exception.nodeId())) {
135136
exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), exception));
136137
}
137138
} else {
138139
@SuppressWarnings("unchecked")
139-
NodeResponse response = (NodeResponse) responses.get(i);
140-
broadcastByNodeResponses.addAll(response.results);
141-
totalShards += response.getTotalShards();
142-
successfulShards += response.getSuccessfulShards();
143-
for (BroadcastShardOperationFailedException throwable : response.getExceptions()) {
140+
NodeResponse nodeResponse = (NodeResponse) response;
141+
broadcastByNodeResponses.addAll(nodeResponse.results);
142+
totalShards += nodeResponse.getTotalShards();
143+
successfulShards += nodeResponse.getSuccessfulShards();
144+
for (BroadcastShardOperationFailedException throwable : nodeResponse.getExceptions()) {
144145
if (TransportActions.isShardNotAvailableException(throwable) == false) {
145146
exceptions.add(
146147
new DefaultShardOperationFailedException(
@@ -257,16 +258,15 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
257258
new AsyncAction(task, request, listener).start();
258259
}
259260

260-
protected class AsyncAction {
261+
protected class AsyncAction implements CancellableTask.CancellationListener {
261262
private final Task task;
262263
private final Request request;
263264
private final ActionListener<Response> listener;
264265
private final ClusterState clusterState;
265266
private final DiscoveryNodes nodes;
266267
private final Map<String, List<ShardRouting>> nodeIds;
267-
private final AtomicReferenceArray<Object> responses;
268-
private final AtomicInteger counter = new AtomicInteger();
269268
private final int unavailableShardCount;
269+
private final NodeResponseTracker nodeResponseTracker;
270270

271271
protected AsyncAction(Task task, Request request, ActionListener<Response> listener) {
272272
this.task = task;
@@ -313,10 +313,14 @@ protected AsyncAction(Task task, Request request, ActionListener<Response> liste
313313

314314
}
315315
this.unavailableShardCount = unavailableShardCount;
316-
responses = new AtomicReferenceArray<>(nodeIds.size());
316+
nodeResponseTracker = new NodeResponseTracker(nodeIds.size());
317317
}
318318

319319
public void start() {
320+
if (task instanceof CancellableTask) {
321+
CancellableTask cancellableTask = (CancellableTask) task;
322+
cancellableTask.addListener(this);
323+
}
320324
if (nodeIds.size() == 0) {
321325
try {
322326
onCompletion();
@@ -374,38 +378,37 @@ protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse re
374378
logger.trace("received response for [{}] from node [{}]", actionName, node.getId());
375379
}
376380

377-
// this is defensive to protect against the possibility of double invocation
378-
// the current implementation of TransportService#sendRequest guards against this
379-
// but concurrency is hard, safety is important, and the small performance loss here does not matter
380-
if (responses.compareAndSet(nodeIndex, null, response)) {
381-
if (counter.incrementAndGet() == responses.length()) {
382-
onCompletion();
383-
}
381+
if (nodeResponseTracker.trackResponseAndCheckIfLast(nodeIndex, response)) {
382+
onCompletion();
384383
}
385384
}
386385

387386
protected void onNodeFailure(DiscoveryNode node, int nodeIndex, Throwable t) {
388387
String nodeId = node.getId();
389388
logger.debug(new ParameterizedMessage("failed to execute [{}] on node [{}]", actionName, nodeId), t);
390-
391-
// this is defensive to protect against the possibility of double invocation
392-
// the current implementation of TransportService#sendRequest guards against this
393-
// but concurrency is hard, safety is important, and the small performance loss here does not matter
394-
if (responses.compareAndSet(nodeIndex, null, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
395-
if (counter.incrementAndGet() == responses.length()) {
396-
onCompletion();
397-
}
389+
if (nodeResponseTracker.trackResponseAndCheckIfLast(
390+
nodeIndex,
391+
new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)
392+
)) {
393+
onCompletion();
398394
}
399395
}
400396

401397
protected void onCompletion() {
402-
if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) {
403-
return;
398+
if (task instanceof CancellableTask) {
399+
CancellableTask cancellableTask = (CancellableTask) task;
400+
if (cancellableTask.notifyIfCancelled(listener)) {
401+
return;
402+
}
404403
}
405404

406405
Response response = null;
407406
try {
408-
response = newResponse(request, responses, unavailableShardCount, nodeIds, clusterState);
407+
response = newResponse(request, nodeResponseTracker, unavailableShardCount, nodeIds, clusterState);
408+
} catch (NodeResponseTracker.DiscardedResponsesException e) {
409+
// We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take
410+
// follow-up actions
411+
listener.onFailure((Exception) e.getCause());
409412
} catch (Exception e) {
410413
logger.debug("failed to combine responses from nodes", e);
411414
listener.onFailure(e);
@@ -418,6 +421,21 @@ protected void onCompletion() {
418421
}
419422
}
420423
}
424+
425+
@Override
426+
public void onCancelled() {
427+
assert task instanceof CancellableTask : "task must be cancellable";
428+
try {
429+
((CancellableTask) task).ensureNotCancelled();
430+
} catch (TaskCancelledException e) {
431+
nodeResponseTracker.discardIntermediateResponses(e);
432+
}
433+
}
434+
435+
// For testing purposes
436+
public NodeResponseTracker getNodeResponseTracker() {
437+
return nodeResponseTracker;
438+
}
421439
}
422440

423441
class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler<NodeRequest> {

0 commit comments

Comments
 (0)