12
12
import org .elasticsearch .action .ActionListener ;
13
13
import org .elasticsearch .action .StepListener ;
14
14
import org .elasticsearch .action .support .PlainActionFuture ;
15
+ import org .elasticsearch .common .settings .Settings ;
15
16
import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
17
+ import org .elasticsearch .common .util .concurrent .ThreadContext ;
16
18
import org .elasticsearch .tasks .TaskCancelledException ;
17
19
import org .elasticsearch .test .ESTestCase ;
18
20
import org .elasticsearch .threadpool .TestThreadPool ;
@@ -193,7 +195,8 @@ public void testExceptionCompletesListenersButIsNotCached() {
193
195
public void testConcurrentRefreshesAndCancellation () throws InterruptedException {
194
196
final ThreadPool threadPool = new TestThreadPool ("test" );
195
197
try {
196
- final CancellableSingleObjectCache <String , String , Integer > testCache = new CancellableSingleObjectCache <>() {
198
+ final ThreadContext threadContext = threadPool .getThreadContext ();
199
+ final CancellableSingleObjectCache <String , String , Integer > testCache = new CancellableSingleObjectCache <>(threadContext ) {
197
200
@ Override
198
201
protected void refresh (
199
202
String s ,
@@ -219,6 +222,7 @@ protected String getKey(String s) {
219
222
final CountDownLatch startLatch = new CountDownLatch (1 );
220
223
final CountDownLatch finishLatch = new CountDownLatch (count );
221
224
final BlockingQueue <Runnable > queue = ConcurrentCollections .newBlockingQueue ();
225
+ final String contextHeader = "test-context-header" ;
222
226
223
227
for (int i = 0 ; i < count ; i ++) {
224
228
final boolean cancel = randomBoolean ();
@@ -233,11 +237,14 @@ protected String getKey(String s) {
233
237
final StepListener <Integer > stepListener = new StepListener <>();
234
238
final AtomicBoolean isComplete = new AtomicBoolean ();
235
239
final AtomicBoolean isCancelled = new AtomicBoolean ();
236
- testCache .get (
237
- input ,
238
- isCancelled ::get ,
239
- ActionListener .runBefore (stepListener , () -> assertTrue (isComplete .compareAndSet (false , true )))
240
- );
240
+ try (ThreadContext .StoredContext ignored = threadContext .stashContext ()) {
241
+ final String contextValue = randomAlphaOfLength (10 );
242
+ threadContext .putHeader (contextHeader , contextValue );
243
+ testCache .get (input , isCancelled ::get , ActionListener .runBefore (stepListener , () -> {
244
+ assertTrue (isComplete .compareAndSet (false , true ));
245
+ assertThat (threadContext .getHeader (contextHeader ), equalTo (contextValue ));
246
+ }));
247
+ }
241
248
242
249
final Runnable next = queue .poll ();
243
250
if (next != null ) {
@@ -277,7 +284,9 @@ protected String getKey(String s) {
277
284
public void testConcurrentRefreshesWithFreshnessCheck () throws InterruptedException {
278
285
final ThreadPool threadPool = new TestThreadPool ("test" );
279
286
try {
280
- final CancellableSingleObjectCache <String , String , Integer > testCache = new CancellableSingleObjectCache <>() {
287
+ final CancellableSingleObjectCache <String , String , Integer > testCache = new CancellableSingleObjectCache <>(
288
+ threadPool .getThreadContext ()
289
+ ) {
281
290
@ Override
282
291
protected void refresh (
283
292
String s ,
@@ -380,7 +389,7 @@ public void run() {
380
389
}
381
390
};
382
391
383
- final CancellableSingleObjectCache <String , String , Integer > testCache = new CancellableSingleObjectCache <>() {
392
+ final CancellableSingleObjectCache <String , String , Integer > testCache = new CancellableSingleObjectCache <>(testThreadContext ) {
384
393
@ Override
385
394
protected void refresh (
386
395
String s ,
@@ -424,10 +433,16 @@ protected String getKey(String s) {
424
433
expectThrows (TaskCancelledException .class , () -> cancelledFuture .actionGet (0L ));
425
434
}
426
435
436
+ private static final ThreadContext testThreadContext = new ThreadContext (Settings .EMPTY );
437
+
427
438
private static class TestCache extends CancellableSingleObjectCache <String , String , Integer > {
428
439
429
440
private final LinkedList <StepListener <Function <String , Integer >>> pendingRefreshes = new LinkedList <>();
430
441
442
+ private TestCache () {
443
+ super (testThreadContext );
444
+ }
445
+
431
446
@ Override
432
447
protected void refresh (
433
448
String input ,
0 commit comments