47
47
import org .elasticsearch .common .util .concurrent .EsRejectedExecutionException ;
48
48
import org .elasticsearch .common .util .concurrent .FutureUtils ;
49
49
import org .elasticsearch .common .util .concurrent .PrioritizedEsThreadPoolExecutor ;
50
+ import org .elasticsearch .common .util .concurrent .ThreadContext ;
50
51
import org .elasticsearch .discovery .Discovery ;
51
52
import org .elasticsearch .threadpool .ThreadPool ;
52
53
59
60
import java .util .concurrent .Future ;
60
61
import java .util .concurrent .TimeUnit ;
61
62
import java .util .function .BiConsumer ;
63
+ import java .util .function .Supplier ;
62
64
import java .util .stream .Collectors ;
63
65
64
66
import static org .elasticsearch .cluster .service .ClusterService .CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING ;
@@ -426,26 +428,28 @@ public TimeValue getMaxTaskWaitTime() {
426
428
return threadPoolExecutor .getMaxTaskWaitTime ();
427
429
}
428
430
429
- private SafeClusterStateTaskListener safe (ClusterStateTaskListener listener ) {
431
+ private SafeClusterStateTaskListener safe (ClusterStateTaskListener listener , Supplier < ThreadContext . StoredContext > contextSupplier ) {
430
432
if (listener instanceof AckedClusterStateTaskListener ) {
431
- return new SafeAckedClusterStateTaskListener ((AckedClusterStateTaskListener ) listener , logger );
433
+ return new SafeAckedClusterStateTaskListener ((AckedClusterStateTaskListener ) listener , contextSupplier , logger );
432
434
} else {
433
- return new SafeClusterStateTaskListener (listener , logger );
435
+ return new SafeClusterStateTaskListener (listener , contextSupplier , logger );
434
436
}
435
437
}
436
438
437
439
private static class SafeClusterStateTaskListener implements ClusterStateTaskListener {
438
440
private final ClusterStateTaskListener listener ;
441
+ protected final Supplier <ThreadContext .StoredContext > context ;
439
442
private final Logger logger ;
440
443
441
- SafeClusterStateTaskListener (ClusterStateTaskListener listener , Logger logger ) {
444
+ SafeClusterStateTaskListener (ClusterStateTaskListener listener , Supplier < ThreadContext . StoredContext > context , Logger logger ) {
442
445
this .listener = listener ;
446
+ this .context = context ;
443
447
this .logger = logger ;
444
448
}
445
449
446
450
@ Override
447
451
public void onFailure (String source , Exception e ) {
448
- try {
452
+ try ( ThreadContext . StoredContext ignore = context . get ()) {
449
453
listener .onFailure (source , e );
450
454
} catch (Exception inner ) {
451
455
inner .addSuppressed (e );
@@ -456,7 +460,7 @@ public void onFailure(String source, Exception e) {
456
460
457
461
@ Override
458
462
public void onNoLongerMaster (String source ) {
459
- try {
463
+ try ( ThreadContext . StoredContext ignore = context . get ()) {
460
464
listener .onNoLongerMaster (source );
461
465
} catch (Exception e ) {
462
466
logger .error (() -> new ParameterizedMessage (
@@ -466,7 +470,7 @@ public void onNoLongerMaster(String source) {
466
470
467
471
@ Override
468
472
public void clusterStateProcessed (String source , ClusterState oldState , ClusterState newState ) {
469
- try {
473
+ try ( ThreadContext . StoredContext ignore = context . get ()) {
470
474
listener .clusterStateProcessed (source , oldState , newState );
471
475
} catch (Exception e ) {
472
476
logger .error (() -> new ParameterizedMessage (
@@ -480,8 +484,9 @@ private static class SafeAckedClusterStateTaskListener extends SafeClusterStateT
480
484
private final AckedClusterStateTaskListener listener ;
481
485
private final Logger logger ;
482
486
483
- SafeAckedClusterStateTaskListener (AckedClusterStateTaskListener listener , Logger logger ) {
484
- super (listener , logger );
487
+ SafeAckedClusterStateTaskListener (AckedClusterStateTaskListener listener , Supplier <ThreadContext .StoredContext > context ,
488
+ Logger logger ) {
489
+ super (listener , context , logger );
485
490
this .listener = listener ;
486
491
this .logger = logger ;
487
492
}
@@ -493,7 +498,7 @@ public boolean mustAck(DiscoveryNode discoveryNode) {
493
498
494
499
@ Override
495
500
public void onAllNodesAcked (@ Nullable Exception e ) {
496
- try {
501
+ try ( ThreadContext . StoredContext ignore = context . get ()) {
497
502
listener .onAllNodesAcked (e );
498
503
} catch (Exception inner ) {
499
504
inner .addSuppressed (e );
@@ -503,7 +508,7 @@ public void onAllNodesAcked(@Nullable Exception e) {
503
508
504
509
@ Override
505
510
public void onAckTimeout () {
506
- try {
511
+ try ( ThreadContext . StoredContext ignore = context . get ()) {
507
512
listener .onAckTimeout ();
508
513
} catch (Exception e ) {
509
514
logger .error ("exception thrown by listener while notifying on ack timeout" , e );
@@ -724,9 +729,13 @@ public <T> void submitStateUpdateTasks(final String source,
724
729
if (!lifecycle .started ()) {
725
730
return ;
726
731
}
727
- try {
732
+ final ThreadContext threadContext = threadPool .getThreadContext ();
733
+ final Supplier <ThreadContext .StoredContext > supplier = threadContext .newRestorableContext (false );
734
+ try (ThreadContext .StoredContext ignore = threadContext .stashContext ()) {
735
+ threadContext .markAsSystemContext ();
736
+
728
737
List <Batcher .UpdateTask > safeTasks = tasks .entrySet ().stream ()
729
- .map (e -> taskBatcher .new UpdateTask (config .priority (), source , e .getKey (), safe (e .getValue ()), executor ))
738
+ .map (e -> taskBatcher .new UpdateTask (config .priority (), source , e .getKey (), safe (e .getValue (), supplier ), executor ))
730
739
.collect (Collectors .toList ());
731
740
taskBatcher .submitTasks (safeTasks , config .timeout ());
732
741
} catch (EsRejectedExecutionException e ) {
0 commit comments