16
16
import org .elasticsearch .action .support .DefaultShardOperationFailedException ;
17
17
import org .elasticsearch .action .support .HandledTransportAction ;
18
18
import org .elasticsearch .action .support .IndicesOptions ;
19
+ import org .elasticsearch .action .support .NodeResponseTracker ;
19
20
import org .elasticsearch .action .support .TransportActions ;
20
21
import org .elasticsearch .action .support .broadcast .BroadcastRequest ;
21
22
import org .elasticsearch .action .support .broadcast .BroadcastResponse ;
51
52
import java .util .List ;
52
53
import java .util .Map ;
53
54
import java .util .concurrent .atomic .AtomicInteger ;
54
- import java .util .concurrent .atomic .AtomicReferenceArray ;
55
55
import java .util .function .Consumer ;
56
56
57
57
/**
@@ -118,29 +118,30 @@ public TransportBroadcastByNodeAction(
118
118
119
119
private Response newResponse (
120
120
Request request ,
121
- AtomicReferenceArray <?> responses ,
121
+ NodeResponseTracker nodeResponseTracker ,
122
122
int unavailableShardCount ,
123
123
Map <String , List <ShardRouting >> nodes ,
124
124
ClusterState clusterState
125
- ) {
125
+ ) throws NodeResponseTracker . DiscardedResponsesException {
126
126
int totalShards = 0 ;
127
127
int successfulShards = 0 ;
128
128
List <ShardOperationResult > broadcastByNodeResponses = new ArrayList <>();
129
129
List <DefaultShardOperationFailedException > exceptions = new ArrayList <>();
130
- for (int i = 0 ; i < responses .length (); i ++) {
131
- if (responses .get (i ) instanceof FailedNodeException ) {
132
- FailedNodeException exception = (FailedNodeException ) responses .get (i );
130
+ for (int i = 0 ; i < nodeResponseTracker .getExpectedResponseCount (); i ++) {
131
+ Object response = nodeResponseTracker .getResponse (i );
132
+ if (response instanceof FailedNodeException ) {
133
+ FailedNodeException exception = (FailedNodeException ) response ;
133
134
totalShards += nodes .get (exception .nodeId ()).size ();
134
135
for (ShardRouting shard : nodes .get (exception .nodeId ())) {
135
136
exceptions .add (new DefaultShardOperationFailedException (shard .getIndexName (), shard .getId (), exception ));
136
137
}
137
138
} else {
138
139
@ SuppressWarnings ("unchecked" )
139
- NodeResponse response = (NodeResponse ) responses . get ( i ) ;
140
- broadcastByNodeResponses .addAll (response .results );
141
- totalShards += response .getTotalShards ();
142
- successfulShards += response .getSuccessfulShards ();
143
- for (BroadcastShardOperationFailedException throwable : response .getExceptions ()) {
140
+ NodeResponse nodeResponse = (NodeResponse ) response ;
141
+ broadcastByNodeResponses .addAll (nodeResponse .results );
142
+ totalShards += nodeResponse .getTotalShards ();
143
+ successfulShards += nodeResponse .getSuccessfulShards ();
144
+ for (BroadcastShardOperationFailedException throwable : nodeResponse .getExceptions ()) {
144
145
if (TransportActions .isShardNotAvailableException (throwable ) == false ) {
145
146
exceptions .add (
146
147
new DefaultShardOperationFailedException (
@@ -257,16 +258,15 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
257
258
new AsyncAction (task , request , listener ).start ();
258
259
}
259
260
260
- protected class AsyncAction {
261
+ protected class AsyncAction implements CancellableTask . CancellationListener {
261
262
private final Task task ;
262
263
private final Request request ;
263
264
private final ActionListener <Response > listener ;
264
265
private final ClusterState clusterState ;
265
266
private final DiscoveryNodes nodes ;
266
267
private final Map <String , List <ShardRouting >> nodeIds ;
267
- private final AtomicReferenceArray <Object > responses ;
268
- private final AtomicInteger counter = new AtomicInteger ();
269
268
private final int unavailableShardCount ;
269
+ private final NodeResponseTracker nodeResponseTracker ;
270
270
271
271
protected AsyncAction (Task task , Request request , ActionListener <Response > listener ) {
272
272
this .task = task ;
@@ -313,10 +313,14 @@ protected AsyncAction(Task task, Request request, ActionListener<Response> liste
313
313
314
314
}
315
315
this .unavailableShardCount = unavailableShardCount ;
316
- responses = new AtomicReferenceArray <> (nodeIds .size ());
316
+ nodeResponseTracker = new NodeResponseTracker (nodeIds .size ());
317
317
}
318
318
319
319
public void start () {
320
+ if (task instanceof CancellableTask ) {
321
+ CancellableTask cancellableTask = (CancellableTask ) task ;
322
+ cancellableTask .addListener (this );
323
+ }
320
324
if (nodeIds .size () == 0 ) {
321
325
try {
322
326
onCompletion ();
@@ -374,38 +378,37 @@ protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse re
374
378
logger .trace ("received response for [{}] from node [{}]" , actionName , node .getId ());
375
379
}
376
380
377
- // this is defensive to protect against the possibility of double invocation
378
- // the current implementation of TransportService#sendRequest guards against this
379
- // but concurrency is hard, safety is important, and the small performance loss here does not matter
380
- if (responses .compareAndSet (nodeIndex , null , response )) {
381
- if (counter .incrementAndGet () == responses .length ()) {
382
- onCompletion ();
383
- }
381
+ if (nodeResponseTracker .trackResponseAndCheckIfLast (nodeIndex , response )) {
382
+ onCompletion ();
384
383
}
385
384
}
386
385
387
386
protected void onNodeFailure (DiscoveryNode node , int nodeIndex , Throwable t ) {
388
387
String nodeId = node .getId ();
389
388
logger .debug (new ParameterizedMessage ("failed to execute [{}] on node [{}]" , actionName , nodeId ), t );
390
-
391
- // this is defensive to protect against the possibility of double invocation
392
- // the current implementation of TransportService#sendRequest guards against this
393
- // but concurrency is hard, safety is important, and the small performance loss here does not matter
394
- if (responses .compareAndSet (nodeIndex , null , new FailedNodeException (nodeId , "Failed node [" + nodeId + "]" , t ))) {
395
- if (counter .incrementAndGet () == responses .length ()) {
396
- onCompletion ();
397
- }
389
+ if (nodeResponseTracker .trackResponseAndCheckIfLast (
390
+ nodeIndex ,
391
+ new FailedNodeException (nodeId , "Failed node [" + nodeId + "]" , t )
392
+ )) {
393
+ onCompletion ();
398
394
}
399
395
}
400
396
401
397
protected void onCompletion () {
402
- if (task instanceof CancellableTask && ((CancellableTask ) task ).notifyIfCancelled (listener )) {
403
- return ;
398
+ if (task instanceof CancellableTask ) {
399
+ CancellableTask cancellableTask = (CancellableTask ) task ;
400
+ if (cancellableTask .notifyIfCancelled (listener )) {
401
+ return ;
402
+ }
404
403
}
405
404
406
405
Response response = null ;
407
406
try {
408
- response = newResponse (request , responses , unavailableShardCount , nodeIds , clusterState );
407
+ response = newResponse (request , nodeResponseTracker , unavailableShardCount , nodeIds , clusterState );
408
+ } catch (NodeResponseTracker .DiscardedResponsesException e ) {
409
+ // We propagate the reason that the results, in this case the task cancellation, in case the listener needs to take
410
+ // follow-up actions
411
+ listener .onFailure ((Exception ) e .getCause ());
409
412
} catch (Exception e ) {
410
413
logger .debug ("failed to combine responses from nodes" , e );
411
414
listener .onFailure (e );
@@ -418,6 +421,21 @@ protected void onCompletion() {
418
421
}
419
422
}
420
423
}
424
+
425
+ @ Override
426
+ public void onCancelled () {
427
+ assert task instanceof CancellableTask : "task must be cancellable" ;
428
+ try {
429
+ ((CancellableTask ) task ).ensureNotCancelled ();
430
+ } catch (TaskCancelledException e ) {
431
+ nodeResponseTracker .discardIntermediateResponses (e );
432
+ }
433
+ }
434
+
435
+ // For testing purposes
436
+ public NodeResponseTracker getNodeResponseTracker () {
437
+ return nodeResponseTracker ;
438
+ }
421
439
}
422
440
423
441
class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler <NodeRequest > {
0 commit comments