Skip to content

Commit 1ba6783

Browse files
committed
Schedule commands in current thread context (#54187)
Changes ThreadPool's schedule method to run the schedule task in the context of the thread that scheduled the task. This is the more sensible default for this method, and eliminates a range of bugs where the current thread context is mistakenly dropped. Closes #17143
1 parent 32d0bb8 commit 1ba6783

File tree

9 files changed

+53
-54
lines changed

9 files changed

+53
-54
lines changed

server/src/main/java/org/elasticsearch/action/bulk/BulkProcessor.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -435,8 +435,7 @@ public boolean isCancelled() {
435435
}
436436
};
437437
}
438-
final Runnable flushRunnable = scheduler.preserveContext(new Flush());
439-
return scheduler.scheduleWithFixedDelay(flushRunnable, flushInterval, ThreadPool.Names.GENERIC);
438+
return scheduler.scheduleWithFixedDelay(new Flush(), flushInterval, ThreadPool.Names.GENERIC);
440439
}
441440

442441
// needs to be executed under a lock

server/src/main/java/org/elasticsearch/action/bulk/Retry.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,7 @@ private void retry(BulkRequest bulkRequestForRetry) {
136136
assert backoff.hasNext();
137137
TimeValue next = backoff.next();
138138
logger.trace("Retry of bulk request scheduled in {} ms.", next.millis());
139-
Runnable command = scheduler.preserveContext(() -> this.execute(bulkRequestForRetry));
140-
retryCancellable = scheduler.schedule(command, next, ThreadPool.Names.SAME);
139+
retryCancellable = scheduler.schedule(() -> this.execute(bulkRequestForRetry), next, ThreadPool.Names.SAME);
141140
}
142141

143142
private BulkRequest createBulkRequestForRetry(BulkResponse bulkItemResponses) {

server/src/main/java/org/elasticsearch/index/reindex/RetryListener.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ public void onRejection(Exception e) {
7272
}
7373

7474
private void schedule(Runnable runnable, TimeValue delay) {
75-
// schedule does not preserve context so have to do this manually
76-
threadPool.schedule(threadPool.preserveContext(runnable), delay, ThreadPool.Names.SAME);
75+
threadPool.schedule(runnable, delay, ThreadPool.Names.SAME);
7776
}
7877
}

server/src/main/java/org/elasticsearch/threadpool/Scheduler.java

+1-11
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,7 @@ static boolean awaitTermination(final ScheduledThreadPoolExecutor scheduledThrea
8383
}
8484

8585
/**
86-
* Does nothing by default but can be used by subclasses to save the current thread context and wraps the command in a Runnable
87-
* that restores that context before running the command.
88-
*/
89-
default Runnable preserveContext(Runnable command) {
90-
return command;
91-
}
92-
93-
/**
94-
* Schedules a one-shot command to be run after a given delay. The command is not run in the context of the calling thread.
95-
* To preserve the context of the calling thread you may call {@link #preserveContext(Runnable)} on the runnable before passing
96-
* it to this method.
86+
* Schedules a one-shot command to be run after a given delay. The command is run in the context of the calling thread.
9787
* The command runs on scheduler thread. Do not run blocking calls on the scheduler thread. Subclasses may allow
9888
* to execute on a different executor, in which case blocking calls are allowed.
9989
*

server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java

+2-8
Original file line numberDiff line numberDiff line change
@@ -335,9 +335,7 @@ public ExecutorService executor(String name) {
335335
}
336336

337337
/**
338-
* Schedules a one-shot command to run after a given delay. The command is not run in the context of the calling thread. To preserve the
339-
* context of the calling thread you may call <code>threadPool.getThreadContext().preserveContext</code> on the runnable before passing
340-
* it to this method.
338+
* Schedules a one-shot command to run after a given delay. The command is run in the context of the calling thread.
341339
*
342340
* @param command the command to run
343341
* @param delay delay before the task executes
@@ -351,6 +349,7 @@ public ExecutorService executor(String name) {
351349
*/
352350
@Override
353351
public ScheduledCancellable schedule(Runnable command, TimeValue delay, String executor) {
352+
command = threadContext.preserveContext(command);
354353
if (!Names.SAME.equals(executor)) {
355354
command = new ThreadedRunnable(command, executor(executor));
356355
}
@@ -383,11 +382,6 @@ public Cancellable scheduleWithFixedDelay(Runnable command, TimeValue interval,
383382
command, executor), e));
384383
}
385384

386-
@Override
387-
public Runnable preserveContext(Runnable command) {
388-
return getThreadContext().preserveContext(command);
389-
}
390-
391385
protected final void stopCachedTimeThread() {
392386
cachedTimeThread.running = false;
393387
cachedTimeThread.interrupt();

server/src/test/java/org/elasticsearch/threadpool/ThreadPoolTests.java

+33
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
package org.elasticsearch.threadpool;
2121

2222
import org.elasticsearch.common.settings.Settings;
23+
import org.elasticsearch.common.unit.TimeValue;
2324
import org.elasticsearch.common.util.concurrent.EsExecutors;
2425
import org.elasticsearch.common.util.concurrent.FutureUtils;
2526
import org.elasticsearch.test.ESTestCase;
2627

28+
import java.util.concurrent.CountDownLatch;
2729
import java.util.concurrent.ExecutorService;
2830

2931
import static org.elasticsearch.threadpool.ThreadPool.ESTIMATED_TIME_INTERVAL_SETTING;
@@ -102,4 +104,35 @@ public void testAssertCurrentMethodIsNotCalledRecursively() {
102104
equalTo("org.elasticsearch.threadpool.ThreadPoolTests#factorialForked is called recursively"));
103105
terminate(threadPool);
104106
}
107+
108+
public void testInheritContextOnSchedule() throws InterruptedException {
109+
final CountDownLatch latch = new CountDownLatch(1);
110+
final CountDownLatch executed = new CountDownLatch(1);
111+
112+
TestThreadPool threadPool = new TestThreadPool("test");
113+
try {
114+
threadPool.getThreadContext().putHeader("foo", "bar");
115+
final Integer one = Integer.valueOf(1);
116+
threadPool.getThreadContext().putTransient("foo", one);
117+
threadPool.schedule(() -> {
118+
try {
119+
latch.await();
120+
} catch (InterruptedException e) {
121+
fail();
122+
}
123+
assertEquals(threadPool.getThreadContext().getHeader("foo"), "bar");
124+
assertSame(threadPool.getThreadContext().getTransient("foo"), one);
125+
assertNull(threadPool.getThreadContext().getHeader("bar"));
126+
assertNull(threadPool.getThreadContext().getTransient("bar"));
127+
executed.countDown();
128+
}, TimeValue.timeValueMillis(randomInt(100)), randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC));
129+
threadPool.getThreadContext().putTransient("bar", "boom");
130+
threadPool.getThreadContext().putHeader("bar", "boom");
131+
latch.countDown();
132+
executed.await();
133+
} finally {
134+
latch.countDown();
135+
terminate(threadPool);
136+
}
137+
}
105138
}

test/framework/src/main/java/org/elasticsearch/cluster/coordination/DeterministicTaskQueue.java

-5
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,6 @@ public Cancellable scheduleWithFixedDelay(Runnable command, TimeValue interval,
381381
return super.scheduleWithFixedDelay(command, interval, executor);
382382
}
383383

384-
@Override
385-
public Runnable preserveContext(Runnable command) {
386-
return command;
387-
}
388-
389384
@Override
390385
public void shutdown() {
391386
throw new UnsupportedOperationException();

x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,13 @@ private void internalAddCompletionListener(ActionListener<AsyncSearchResponse> l
217217

218218
final Cancellable cancellable;
219219
try {
220-
cancellable = threadPool.schedule(threadPool.preserveContext(() -> {
220+
cancellable = threadPool.schedule(() -> {
221221
if (hasRun.compareAndSet(false, true)) {
222222
// timeout occurred before completion
223223
removeCompletionListener(id);
224224
listener.onResponse(getResponse());
225225
}
226-
}), waitForCompletion, "generic");
226+
}, waitForCompletion, "generic");
227227
} catch (EsRejectedExecutionException exc) {
228228
listener.onFailure(exc);
229229
return;

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java

+12-22
Original file line numberDiff line numberDiff line change
@@ -787,10 +787,8 @@ private void indexInvalidation(Collection<String> tokenIds, SecurityIndexManager
787787
retryTokenDocIds.size(), tokenIds.size());
788788
final TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated,
789789
previouslyInvalidated, failedRequestResponses);
790-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
791-
.preserveContext(() -> indexInvalidation(retryTokenDocIds, tokensIndexManager, backoff,
792-
srcPrefix, incompleteResult, listener));
793-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
790+
client.threadPool().schedule(() -> indexInvalidation(retryTokenDocIds, tokensIndexManager, backoff,
791+
srcPrefix, incompleteResult, listener), backoff.next(), GENERIC);
794792
} else {
795793
if (retryTokenDocIds.isEmpty() == false) {
796794
logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries", retryTokenDocIds.size(),
@@ -810,10 +808,8 @@ private void indexInvalidation(Collection<String> tokenIds, SecurityIndexManager
810808
traceLog("invalidate tokens", cause);
811809
if (isShardNotAvailableException(cause) && backoff.hasNext()) {
812810
logger.debug("failed to invalidate tokens, retrying ");
813-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
814-
.preserveContext(() -> indexInvalidation(tokenIds, tokensIndexManager, backoff, srcPrefix,
815-
previousResult, listener));
816-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
811+
client.threadPool().schedule(() -> indexInvalidation(tokenIds, tokensIndexManager, backoff, srcPrefix,
812+
previousResult, listener), backoff.next(), GENERIC);
817813
} else {
818814
listener.onFailure(e);
819815
}
@@ -895,9 +891,8 @@ private void findTokenFromRefreshToken(String refreshToken, SecurityIndexManager
895891
if (backoff.hasNext()) {
896892
final TimeValue backofTimeValue = backoff.next();
897893
logger.debug("retrying after [{}] back off", backofTimeValue);
898-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
899-
.preserveContext(() -> findTokenFromRefreshToken(refreshToken, tokensIndexManager, backoff, listener));
900-
client.threadPool().schedule(retryWithContextRunnable, backofTimeValue, GENERIC);
894+
client.threadPool().schedule(() -> findTokenFromRefreshToken(refreshToken, tokensIndexManager, backoff, listener),
895+
backofTimeValue, GENERIC);
901896
} else {
902897
logger.warn("failed to find token from refresh token after all retries");
903898
onFailure.accept(ex);
@@ -1019,10 +1014,8 @@ private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Ob
10191014
} else if (backoff.hasNext()) {
10201015
logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying",
10211016
tokenDocId, updateResponse.getResult());
1022-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
1023-
.preserveContext(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm, clientAuth,
1024-
backoff, refreshRequested, listener));
1025-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
1017+
client.threadPool().schedule(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm,
1018+
clientAuth, backoff, refreshRequested, listener), backoff.next(), GENERIC);
10261019
} else {
10271020
logger.info("failed to update the original token document [{}] after all retries, the update result was [{}]. ",
10281021
tokenDocId, updateResponse.getResult());
@@ -1050,9 +1043,8 @@ public void onFailure(Exception e) {
10501043
if (isShardNotAvailableException(e)) {
10511044
if (backoff.hasNext()) {
10521045
logger.info("could not get token document [{}] for refresh, retrying", tokenDocId);
1053-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
1054-
.preserveContext(() -> getTokenDocAsync(tokenDocId, refreshedTokenIndex, this));
1055-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
1046+
client.threadPool().schedule(() -> getTokenDocAsync(tokenDocId, refreshedTokenIndex, this),
1047+
backoff.next(), GENERIC);
10561048
} else {
10571049
logger.warn("could not get token document [{}] for refresh after all retries", tokenDocId);
10581050
onFailure.accept(invalidGrantException("could not refresh the requested token"));
@@ -1065,10 +1057,8 @@ public void onFailure(Exception e) {
10651057
} else if (isShardNotAvailableException(e)) {
10661058
if (backoff.hasNext()) {
10671059
logger.debug("failed to update the original token document [{}], retrying", tokenDocId);
1068-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
1069-
.preserveContext(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm,
1070-
clientAuth, backoff, refreshRequested, listener));
1071-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
1060+
client.threadPool().schedule(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm,
1061+
clientAuth, backoff, refreshRequested, listener), backoff.next(), GENERIC);
10721062
} else {
10731063
logger.warn("failed to update the original token document [{}], after all retries", tokenDocId);
10741064
onFailure.accept(invalidGrantException("could not refresh the requested token"));

0 commit comments

Comments
 (0)