16
16
import org .elasticsearch .action .support .DefaultShardOperationFailedException ;
17
17
import org .elasticsearch .action .support .HandledTransportAction ;
18
18
import org .elasticsearch .action .support .IndicesOptions ;
19
+ import org .elasticsearch .action .support .NodeResponseTracker ;
19
20
import org .elasticsearch .action .support .TransportActions ;
20
21
import org .elasticsearch .action .support .broadcast .BroadcastRequest ;
21
22
import org .elasticsearch .action .support .broadcast .BroadcastResponse ;
51
52
import java .util .List ;
52
53
import java .util .Map ;
53
54
import java .util .concurrent .atomic .AtomicInteger ;
54
- import java .util .concurrent .atomic .AtomicReferenceArray ;
55
55
import java .util .function .Consumer ;
56
56
57
57
/**
@@ -118,28 +118,29 @@ public TransportBroadcastByNodeAction(
118
118
119
119
private Response newResponse (
120
120
Request request ,
121
- AtomicReferenceArray <?> responses ,
121
+ NodeResponseTracker nodeResponseTracker ,
122
122
int unavailableShardCount ,
123
123
Map <String , List <ShardRouting >> nodes ,
124
124
ClusterState clusterState
125
- ) {
125
+ ) throws NodeResponseTracker . DiscardedResponsesException {
126
126
int totalShards = 0 ;
127
127
int successfulShards = 0 ;
128
128
List <ShardOperationResult > broadcastByNodeResponses = new ArrayList <>();
129
129
List <DefaultShardOperationFailedException > exceptions = new ArrayList <>();
130
- for (int i = 0 ; i < responses .length (); i ++) {
131
- if (responses .get (i )instanceof FailedNodeException exception ) {
130
+ for (int i = 0 ; i < nodeResponseTracker .getExpectedResponseCount (); i ++) {
131
+ Object response = nodeResponseTracker .getResponse (i );
132
+ if (response instanceof FailedNodeException exception ) {
132
133
totalShards += nodes .get (exception .nodeId ()).size ();
133
134
for (ShardRouting shard : nodes .get (exception .nodeId ())) {
134
135
exceptions .add (new DefaultShardOperationFailedException (shard .getIndexName (), shard .getId (), exception ));
135
136
}
136
137
} else {
137
138
@ SuppressWarnings ("unchecked" )
138
- NodeResponse response = (NodeResponse ) responses . get ( i ) ;
139
- broadcastByNodeResponses .addAll (response .results );
140
- totalShards += response .getTotalShards ();
141
- successfulShards += response .getSuccessfulShards ();
142
- for (BroadcastShardOperationFailedException throwable : response .getExceptions ()) {
139
+ NodeResponse nodeResponse = (NodeResponse ) response ;
140
+ broadcastByNodeResponses .addAll (nodeResponse .results );
141
+ totalShards += nodeResponse .getTotalShards ();
142
+ successfulShards += nodeResponse .getSuccessfulShards ();
143
+ for (BroadcastShardOperationFailedException throwable : nodeResponse .getExceptions ()) {
143
144
if (TransportActions .isShardNotAvailableException (throwable ) == false ) {
144
145
exceptions .add (
145
146
new DefaultShardOperationFailedException (
@@ -256,16 +257,15 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
256
257
new AsyncAction (task , request , listener ).start ();
257
258
}
258
259
259
- protected class AsyncAction {
260
+ protected class AsyncAction implements CancellableTask . CancellationListener {
260
261
private final Task task ;
261
262
private final Request request ;
262
263
private final ActionListener <Response > listener ;
263
264
private final ClusterState clusterState ;
264
265
private final DiscoveryNodes nodes ;
265
266
private final Map <String , List <ShardRouting >> nodeIds ;
266
- private final AtomicReferenceArray <Object > responses ;
267
- private final AtomicInteger counter = new AtomicInteger ();
268
267
private final int unavailableShardCount ;
268
+ private final NodeResponseTracker nodeResponseTracker ;
269
269
270
270
protected AsyncAction (Task task , Request request , ActionListener <Response > listener ) {
271
271
this .task = task ;
@@ -312,10 +312,13 @@ protected AsyncAction(Task task, Request request, ActionListener<Response> liste
312
312
313
313
}
314
314
this .unavailableShardCount = unavailableShardCount ;
315
- responses = new AtomicReferenceArray <> (nodeIds .size ());
315
+ nodeResponseTracker = new NodeResponseTracker (nodeIds .size ());
316
316
}
317
317
318
318
public void start () {
319
+ if (task instanceof CancellableTask cancellableTask ) {
320
+ cancellableTask .addListener (this );
321
+ }
319
322
if (nodeIds .size () == 0 ) {
320
323
try {
321
324
onCompletion ();
@@ -373,38 +376,34 @@ protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse re
373
376
logger .trace ("received response for [{}] from node [{}]" , actionName , node .getId ());
374
377
}
375
378
376
- // this is defensive to protect against the possibility of double invocation
377
- // the current implementation of TransportService#sendRequest guards against this
378
- // but concurrency is hard, safety is important, and the small performance loss here does not matter
379
- if (responses .compareAndSet (nodeIndex , null , response )) {
380
- if (counter .incrementAndGet () == responses .length ()) {
381
- onCompletion ();
382
- }
379
+ if (nodeResponseTracker .trackResponseAndCheckIfLast (nodeIndex , response )) {
380
+ onCompletion ();
383
381
}
384
382
}
385
383
386
384
protected void onNodeFailure (DiscoveryNode node , int nodeIndex , Throwable t ) {
387
385
String nodeId = node .getId ();
388
386
logger .debug (new ParameterizedMessage ("failed to execute [{}] on node [{}]" , actionName , nodeId ), t );
389
-
390
- // this is defensive to protect against the possibility of double invocation
391
- // the current implementation of TransportService#sendRequest guards against this
392
- // but concurrency is hard, safety is important, and the small performance loss here does not matter
393
- if (responses .compareAndSet (nodeIndex , null , new FailedNodeException (nodeId , "Failed node [" + nodeId + "]" , t ))) {
394
- if (counter .incrementAndGet () == responses .length ()) {
395
- onCompletion ();
396
- }
387
+ if (nodeResponseTracker .trackResponseAndCheckIfLast (
388
+ nodeIndex ,
389
+ new FailedNodeException (nodeId , "Failed node [" + nodeId + "]" , t )
390
+ )) {
391
+ onCompletion ();
397
392
}
398
393
}
399
394
400
395
protected void onCompletion () {
401
- if (task instanceof CancellableTask && (( CancellableTask ) task ) .notifyIfCancelled (listener )) {
396
+ if (( task instanceof CancellableTask t ) && t .notifyIfCancelled (listener )) {
402
397
return ;
403
398
}
404
399
405
400
Response response = null ;
406
401
try {
407
- response = newResponse (request , responses , unavailableShardCount , nodeIds , clusterState );
402
+ response = newResponse (request , nodeResponseTracker , unavailableShardCount , nodeIds , clusterState );
403
+ } catch (NodeResponseTracker .DiscardedResponsesException e ) {
404
+ // We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take
405
+ // follow-up actions
406
+ listener .onFailure ((Exception ) e .getCause ());
408
407
} catch (Exception e ) {
409
408
logger .debug ("failed to combine responses from nodes" , e );
410
409
listener .onFailure (e );
@@ -417,6 +416,21 @@ protected void onCompletion() {
417
416
}
418
417
}
419
418
}
419
+
420
+ @ Override
421
+ public void onCancelled () {
422
+ assert task instanceof CancellableTask : "task must be cancellable" ;
423
+ try {
424
+ ((CancellableTask ) task ).ensureNotCancelled ();
425
+ } catch (TaskCancelledException e ) {
426
+ nodeResponseTracker .discardIntermediateResponses (e );
427
+ }
428
+ }
429
+
430
+ // For testing purposes
431
+ public NodeResponseTracker getNodeResponseTracker () {
432
+ return nodeResponseTracker ;
433
+ }
420
434
}
421
435
422
436
class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler <NodeRequest > {
0 commit comments