9
9
10
10
import org .apache .logging .log4j .LogManager ;
11
11
import org .apache .logging .log4j .Logger ;
12
+ import org .elasticsearch .ElasticsearchTimeoutException ;
12
13
import org .elasticsearch .ExceptionsHelper ;
13
14
import org .elasticsearch .TransportVersions ;
14
15
import org .elasticsearch .Version ;
22
23
import org .elasticsearch .action .support .ActionFilters ;
23
24
import org .elasticsearch .action .support .HandledTransportAction ;
24
25
import org .elasticsearch .action .support .RefCountingRunnable ;
26
+ import org .elasticsearch .action .support .SubscribableListener ;
25
27
import org .elasticsearch .cluster .ClusterState ;
26
28
import org .elasticsearch .cluster .node .DiscoveryNode ;
27
29
import org .elasticsearch .cluster .node .DiscoveryNodes ;
@@ -364,6 +366,7 @@ public static class AsyncAction {
364
366
private final DiscoveryNodes discoveryNodes ;
365
367
private final LongSupplier currentTimeMillisSupplier ;
366
368
private final ActionListener <Response > listener ;
369
+ private final SubscribableListener <Void > cancellationListener ;
367
370
private final long timeoutTimeMillis ;
368
371
369
372
// choose the blob path nondeterministically to avoid clashes, assuming that the actual path doesn't matter for reproduction
@@ -394,15 +397,24 @@ public AsyncAction(
394
397
this .discoveryNodes = discoveryNodes ;
395
398
this .currentTimeMillisSupplier = currentTimeMillisSupplier ;
396
399
this .timeoutTimeMillis = currentTimeMillisSupplier .getAsLong () + request .getTimeout ().millis ();
397
- this .listener = listener ;
400
+
401
+ this .cancellationListener = new SubscribableListener <>();
402
+ this .listener = ActionListener .runBefore (listener , () -> cancellationListener .onResponse (null ));
398
403
399
404
responses = new ArrayList <>(request .blobCount );
400
405
}
401
406
402
- private void fail (Exception e ) {
407
+ private boolean setFirstFailure (Exception e ) {
403
408
if (failure .compareAndSet (null , e )) {
404
409
transportService .getTaskManager ().cancelTaskAndDescendants (task , "task failed" , false , ActionListener .noop ());
410
+ return true ;
405
411
} else {
412
+ return false ;
413
+ }
414
+ }
415
+
416
+ private void fail (Exception e ) {
417
+ if (setFirstFailure (e ) == false ) {
406
418
if (innerFailures .tryAcquire ()) {
407
419
final Throwable cause = ExceptionsHelper .unwrapCause (e );
408
420
if (cause instanceof TaskCancelledException || cause instanceof ReceiveTimeoutTransportException ) {
@@ -424,24 +436,34 @@ private boolean isRunning() {
424
436
}
425
437
426
438
if (task .isCancelled ()) {
427
- failure . compareAndSet ( null , new RepositoryVerificationException (request .repositoryName , "verification cancelled" ));
439
+ setFirstFailure ( new RepositoryVerificationException (request .repositoryName , "verification cancelled" ));
428
440
// if this CAS failed then we're failing for some other reason, nbd; also if the task is cancelled then its descendants are
429
441
// also cancelled, so no further action is needed either way.
430
442
return false ;
431
443
}
432
444
433
- if (timeoutTimeMillis < currentTimeMillisSupplier .getAsLong ()) {
434
- if (failure .compareAndSet (
435
- null ,
436
- new RepositoryVerificationException (request .repositoryName , "analysis timed out after [" + request .getTimeout () + "]" )
437
- )) {
438
- transportService .getTaskManager ().cancelTaskAndDescendants (task , "timed out" , false , ActionListener .noop ());
439
- }
440
- // if this CAS failed then we're already failing for some other reason, nbd
441
- return false ;
445
+ return true ;
446
+ }
447
+
448
+ private class CheckForCancelListener implements ActionListener <Void > {
449
+ @ Override
450
+ public void onResponse (Void unused ) {
451
+ // task complete, nothing to do
442
452
}
443
453
444
- return true ;
454
+ @ Override
455
+ public void onFailure (Exception e ) {
456
+ assert e instanceof ElasticsearchTimeoutException : e ;
457
+ if (isRunning ()) {
458
+ // if this CAS fails then we're already failing for some other reason, nbd
459
+ setFirstFailure (
460
+ new RepositoryVerificationException (
461
+ request .repositoryName ,
462
+ "analysis timed out after [" + request .getTimeout () + "]"
463
+ )
464
+ );
465
+ }
466
+ }
445
467
}
446
468
447
469
public void run () {
@@ -450,6 +472,9 @@ public void run() {
450
472
451
473
logger .info ("running analysis of repository [{}] using path [{}]" , request .getRepositoryName (), blobPath );
452
474
475
+ cancellationListener .addTimeout (request .getTimeout (), repository .threadPool (), EsExecutors .DIRECT_EXECUTOR_SERVICE );
476
+ cancellationListener .addListener (new CheckForCancelListener ());
477
+
453
478
final Random random = new Random (request .getSeed ());
454
479
final List <DiscoveryNode > nodes = getSnapshotNodes (discoveryNodes );
455
480
@@ -536,7 +561,7 @@ private void runBlobAnalysis(Releasable ref, final BlobAnalyzeAction.Request req
536
561
BlobAnalyzeAction .NAME ,
537
562
request ,
538
563
task ,
539
- TransportRequestOptions .timeout ( TimeValue . timeValueMillis ( timeoutTimeMillis - currentTimeMillisSupplier . getAsLong ())) ,
564
+ TransportRequestOptions .EMPTY ,
540
565
new ActionListenerResponseHandler <>(ActionListener .releaseAfter (new ActionListener <>() {
541
566
@ Override
542
567
public void onResponse (BlobAnalyzeAction .Response response ) {
0 commit comments