Skip to content

Commit a5473c3

Browse files
committed
ListenableFuture should preserve ThreadContext (#34394)
ListenableFuture may run a listener on the same thread that called the addListener method or it may execute on another thread after the future has completed. Whenever the ListenableFuture stores the listener for execution later, it should preserve the thread context which is what this change does.
1 parent abd24f9 commit a5473c3

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.elasticsearch.common.util.concurrent;
2121

2222
import org.elasticsearch.action.ActionListener;
23+
import org.elasticsearch.action.support.ContextPreservingActionListener;
2324
import org.elasticsearch.common.collect.Tuple;
2425

2526
import java.util.ArrayList;
@@ -47,7 +48,7 @@ public final class ListenableFuture<V> extends BaseFuture<V> implements ActionLi
4748
* If the future has completed, the listener will be notified immediately without forking to
4849
* a different thread.
4950
*/
50-
public void addListener(ActionListener<V> listener, ExecutorService executor) {
51+
public void addListener(ActionListener<V> listener, ExecutorService executor, ThreadContext threadContext) {
5152
if (done) {
5253
// run the callback directly, we don't hold the lock and don't need to fork!
5354
notifyListener(listener, EsExecutors.newDirectExecutorService());
@@ -59,7 +60,7 @@ public void addListener(ActionListener<V> listener, ExecutorService executor) {
5960
if (done) {
6061
run = true;
6162
} else {
62-
listeners.add(new Tuple<>(listener, executor));
63+
listeners.add(new Tuple<>(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext), executor));
6364
run = false;
6465
}
6566
}

server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.common.util.concurrent;
2121

22+
import org.apache.logging.log4j.message.ParameterizedMessage;
2223
import org.elasticsearch.action.ActionListener;
2324
import org.elasticsearch.common.settings.Settings;
2425
import org.elasticsearch.test.ESTestCase;
@@ -30,9 +31,12 @@
3031
import java.util.concurrent.ExecutorService;
3132
import java.util.concurrent.atomic.AtomicInteger;
3233

34+
import static org.hamcrest.Matchers.is;
35+
3336
public class ListenableFutureTests extends ESTestCase {
3437

3538
private ExecutorService executorService;
39+
private ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
3640

3741
@After
3842
public void stopExecutorService() throws InterruptedException {
@@ -46,7 +50,7 @@ public void testListenableFutureNotifiesListeners() {
4650
AtomicInteger notifications = new AtomicInteger(0);
4751
final int numberOfListeners = scaledRandomIntBetween(1, 12);
4852
for (int i = 0; i < numberOfListeners; i++) {
49-
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService());
53+
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService(), threadContext);
5054
}
5155

5256
future.onResponse("");
@@ -63,7 +67,7 @@ public void testListenableFutureNotifiesListenersOnException() {
6367
future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> {
6468
assertEquals(exception, e);
6569
notifications.incrementAndGet();
66-
}), EsExecutors.newDirectExecutorService());
70+
}), EsExecutors.newDirectExecutorService(), threadContext);
6771
}
6872

6973
future.onFailure(exception);
@@ -76,7 +80,7 @@ public void testConcurrentListenerRegistrationAndCompletion() throws BrokenBarri
7680
final int completingThread = randomIntBetween(0, numberOfThreads - 1);
7781
final ListenableFuture<String> future = new ListenableFuture<>();
7882
executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000,
79-
EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY));
83+
EsExecutors.daemonThreadFactory("listener"), threadContext);
8084
final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
8185
final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1);
8286
final AtomicInteger numResponses = new AtomicInteger(0);
@@ -85,20 +89,31 @@ public void testConcurrentListenerRegistrationAndCompletion() throws BrokenBarri
8589
for (int i = 0; i < numberOfThreads; i++) {
8690
final int threadNum = i;
8791
Thread thread = new Thread(() -> {
92+
threadContext.putTransient("key", threadNum);
8893
try {
8994
barrier.await();
9095
if (threadNum == completingThread) {
96+
// we need to do more than just call onResponse as this often results in synchronous
97+
// execution of the listeners instead of actually going async
98+
final int waitTime = randomIntBetween(0, 50);
99+
Thread.sleep(waitTime);
100+
logger.info("completing the future after sleeping {}ms", waitTime);
91101
future.onResponse("");
102+
logger.info("future received response");
92103
} else {
104+
logger.info("adding listener {}", threadNum);
93105
future.addListener(ActionListener.wrap(s -> {
106+
logger.info("listener {} received value {}", threadNum, s);
94107
assertEquals("", s);
108+
assertThat(threadContext.getTransient("key"), is(threadNum));
95109
numResponses.incrementAndGet();
96110
listenersLatch.countDown();
97111
}, e -> {
98-
logger.error("caught unexpected exception", e);
112+
logger.error(new ParameterizedMessage("listener {} caught unexpected exception", threadNum), e);
99113
numExceptions.incrementAndGet();
100114
listenersLatch.countDown();
101-
}), executorService);
115+
}), executorService, threadContext);
116+
logger.info("listener {} added", threadNum);
102117
}
103118
barrier.await();
104119
} catch (InterruptedException | BrokenBarrierException e) {

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ private void authenticateWithCache(UsernamePasswordToken token, ActionListener<A
132132
// is cleared of the failed authentication
133133
cache.invalidate(token.principal(), listenableCacheEntry);
134134
authenticateWithCache(token, listener);
135-
}), threadPool.executor(ThreadPool.Names.GENERIC));
135+
}), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
136136
} else {
137137
// attempt authentication against the authentication source
138138
doAuthenticate(token, ActionListener.wrap(authResult -> {
@@ -234,7 +234,7 @@ private void lookupWithCache(String username, ActionListener<User> listener) {
234234
} else {
235235
listener.onResponse(null);
236236
}
237-
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
237+
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
238238
} catch (final ExecutionException e) {
239239
listener.onFailure(e);
240240
}

x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,14 +468,17 @@ protected void doLookupUser(String username, ActionListener<User> listener) {
468468
List<Thread> threads = new ArrayList<>(numberOfThreads);
469469
for (int i = 0; i < numberOfThreads; i++) {
470470
final boolean invalidPassword = randomBoolean();
471+
final int threadNum = i;
471472
threads.add(new Thread(() -> {
473+
threadPool.getThreadContext().putTransient("key", threadNum);
472474
try {
473475
latch.countDown();
474476
latch.await();
475477
for (int i1 = 0; i1 < numberOfIterations; i1++) {
476478
UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password);
477479

478480
realm.authenticate(token, ActionListener.wrap((result) -> {
481+
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
479482
if (invalidPassword && result.isAuthenticated()) {
480483
throw new RuntimeException("invalid password led to an authenticated user: " + result);
481484
} else if (invalidPassword == false && result.isAuthenticated() == false) {
@@ -528,12 +531,15 @@ protected void doLookupUser(String username, ActionListener<User> listener) {
528531
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
529532
List<Thread> threads = new ArrayList<>(numberOfThreads);
530533
for (int i = 0; i < numberOfThreads; i++) {
534+
final int threadNum = i;
531535
threads.add(new Thread(() -> {
532536
try {
537+
threadPool.getThreadContext().putTransient("key", threadNum);
533538
latch.countDown();
534539
latch.await();
535540
for (int i1 = 0; i1 < numberOfIterations; i1++) {
536541
realm.lookupUser(username, ActionListener.wrap((user) -> {
542+
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
537543
if (user == null) {
538544
throw new RuntimeException("failed to lookup user");
539545
}

0 commit comments

Comments
 (0)