Skip to content

Commit db99325

Browse files
committed
ListenableFuture should preserve ThreadContext (#34394)
ListenableFuture may run a listener on the same thread that called the addListener method or it may execute on another thread after the future has completed. Whenever the ListenableFuture stores the listener for execution later, it should preserve the thread context which is what this change does.
1 parent 610e49c commit db99325

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.elasticsearch.common.util.concurrent;
2121

2222
import org.elasticsearch.action.ActionListener;
23+
import org.elasticsearch.action.support.ContextPreservingActionListener;
2324
import org.elasticsearch.common.collect.Tuple;
2425

2526
import java.util.ArrayList;
@@ -47,7 +48,7 @@ public final class ListenableFuture<V> extends BaseFuture<V> implements ActionLi
4748
* If the future has completed, the listener will be notified immediately without forking to
4849
* a different thread.
4950
*/
50-
public void addListener(ActionListener<V> listener, ExecutorService executor) {
51+
public void addListener(ActionListener<V> listener, ExecutorService executor, ThreadContext threadContext) {
5152
if (done) {
5253
// run the callback directly, we don't hold the lock and don't need to fork!
5354
notifyListener(listener, EsExecutors.newDirectExecutorService());
@@ -59,7 +60,7 @@ public void addListener(ActionListener<V> listener, ExecutorService executor) {
5960
if (done) {
6061
run = true;
6162
} else {
62-
listeners.add(new Tuple<>(listener, executor));
63+
listeners.add(new Tuple<>(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext), executor));
6364
run = false;
6465
}
6566
}

server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.common.util.concurrent;
2121

22+
import org.apache.logging.log4j.message.ParameterizedMessage;
2223
import org.elasticsearch.action.ActionListener;
2324
import org.elasticsearch.common.settings.Settings;
2425
import org.elasticsearch.test.ESTestCase;
@@ -30,9 +31,12 @@
3031
import java.util.concurrent.ExecutorService;
3132
import java.util.concurrent.atomic.AtomicInteger;
3233

34+
import static org.hamcrest.Matchers.is;
35+
3336
public class ListenableFutureTests extends ESTestCase {
3437

3538
private ExecutorService executorService;
39+
private ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
3640

3741
@After
3842
public void stopExecutorService() throws InterruptedException {
@@ -46,7 +50,7 @@ public void testListenableFutureNotifiesListeners() {
4650
AtomicInteger notifications = new AtomicInteger(0);
4751
final int numberOfListeners = scaledRandomIntBetween(1, 12);
4852
for (int i = 0; i < numberOfListeners; i++) {
49-
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService());
53+
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService(), threadContext);
5054
}
5155

5256
future.onResponse("");
@@ -63,7 +67,7 @@ public void testListenableFutureNotifiesListenersOnException() {
6367
future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> {
6468
assertEquals(exception, e);
6569
notifications.incrementAndGet();
66-
}), EsExecutors.newDirectExecutorService());
70+
}), EsExecutors.newDirectExecutorService(), threadContext);
6771
}
6872

6973
future.onFailure(exception);
@@ -76,7 +80,7 @@ public void testConcurrentListenerRegistrationAndCompletion() throws BrokenBarri
7680
final int completingThread = randomIntBetween(0, numberOfThreads - 1);
7781
final ListenableFuture<String> future = new ListenableFuture<>();
7882
executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000,
79-
EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY));
83+
EsExecutors.daemonThreadFactory("listener"), threadContext);
8084
final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
8185
final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1);
8286
final AtomicInteger numResponses = new AtomicInteger(0);
@@ -85,20 +89,31 @@ public void testConcurrentListenerRegistrationAndCompletion() throws BrokenBarri
8589
for (int i = 0; i < numberOfThreads; i++) {
8690
final int threadNum = i;
8791
Thread thread = new Thread(() -> {
92+
threadContext.putTransient("key", threadNum);
8893
try {
8994
barrier.await();
9095
if (threadNum == completingThread) {
96+
// we need to do more than just call onResponse as this often results in synchronous
97+
// execution of the listeners instead of actually going async
98+
final int waitTime = randomIntBetween(0, 50);
99+
Thread.sleep(waitTime);
100+
logger.info("completing the future after sleeping {}ms", waitTime);
91101
future.onResponse("");
102+
logger.info("future received response");
92103
} else {
104+
logger.info("adding listener {}", threadNum);
93105
future.addListener(ActionListener.wrap(s -> {
106+
logger.info("listener {} received value {}", threadNum, s);
94107
assertEquals("", s);
108+
assertThat(threadContext.getTransient("key"), is(threadNum));
95109
numResponses.incrementAndGet();
96110
listenersLatch.countDown();
97111
}, e -> {
98-
logger.error("caught unexpected exception", e);
112+
logger.error(new ParameterizedMessage("listener {} caught unexpected exception", threadNum), e);
99113
numExceptions.incrementAndGet();
100114
listenersLatch.countDown();
101-
}), executorService);
115+
}), executorService, threadContext);
116+
logger.info("listener {} added", threadNum);
102117
}
103118
barrier.await();
104119
} catch (InterruptedException | BrokenBarrierException e) {

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ private void authenticateWithCache(UsernamePasswordToken token, ActionListener<A
122122
listener);
123123
}
124124
}, e -> handleFailure(future, createdAndStartedFuture.get(), token, e, listener)),
125-
threadPool.executor(ThreadPool.Names.GENERIC));
125+
threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
126126
} catch (ExecutionException e) {
127127
listener.onResponse(AuthenticationResult.unsuccessful("", e));
128128
}
@@ -220,7 +220,7 @@ public final void lookupUser(String username, ActionListener<User> listener) {
220220
} else {
221221
listener.onResponse(null);
222222
}
223-
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
223+
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
224224
} catch (ExecutionException e) {
225225
listener.onFailure(e);
226226
}

x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,14 +439,17 @@ protected void doLookupUser(String username, ActionListener<User> listener) {
439439
List<Thread> threads = new ArrayList<>(numberOfThreads);
440440
for (int i = 0; i < numberOfThreads; i++) {
441441
final boolean invalidPassword = randomBoolean();
442+
final int threadNum = i;
442443
threads.add(new Thread(() -> {
444+
threadPool.getThreadContext().putTransient("key", threadNum);
443445
try {
444446
latch.countDown();
445447
latch.await();
446448
for (int i1 = 0; i1 < numberOfIterations; i1++) {
447449
UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password);
448450

449451
realm.authenticate(token, ActionListener.wrap((result) -> {
452+
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
450453
if (invalidPassword && result.isAuthenticated()) {
451454
throw new RuntimeException("invalid password led to an authenticated user: " + result);
452455
} else if (invalidPassword == false && result.isAuthenticated() == false) {
@@ -499,12 +502,15 @@ protected void doLookupUser(String username, ActionListener<User> listener) {
499502
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
500503
List<Thread> threads = new ArrayList<>(numberOfThreads);
501504
for (int i = 0; i < numberOfThreads; i++) {
505+
final int threadNum = i;
502506
threads.add(new Thread(() -> {
503507
try {
508+
threadPool.getThreadContext().putTransient("key", threadNum);
504509
latch.countDown();
505510
latch.await();
506511
for (int i1 = 0; i1 < numberOfIterations; i1++) {
507512
realm.lookupUser(username, ActionListener.wrap((user) -> {
513+
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
508514
if (user == null) {
509515
throw new RuntimeException("failed to lookup user");
510516
}

0 commit comments

Comments
 (0)