Skip to content

Commit 0a34b71

Browse files
authored
Schedule commands in current thread context (#54187)
Changes ThreadPool's schedule method to run the schedule task in the context of the thread that scheduled the task. This is the more sensible default for this method, and eliminates a range of bugs where the current thread context is mistakenly dropped. Closes #17143
1 parent 1c48214 commit 0a34b71

File tree

9 files changed

+53
-54
lines changed

9 files changed

+53
-54
lines changed

server/src/main/java/org/elasticsearch/action/bulk/BulkProcessor.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -420,8 +420,7 @@ public boolean isCancelled() {
420420
}
421421
};
422422
}
423-
final Runnable flushRunnable = scheduler.preserveContext(new Flush());
424-
return scheduler.scheduleWithFixedDelay(flushRunnable, flushInterval, ThreadPool.Names.GENERIC);
423+
return scheduler.scheduleWithFixedDelay(new Flush(), flushInterval, ThreadPool.Names.GENERIC);
425424
}
426425

427426
// needs to be executed under a lock

server/src/main/java/org/elasticsearch/action/bulk/Retry.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,7 @@ private void retry(BulkRequest bulkRequestForRetry) {
136136
assert backoff.hasNext();
137137
TimeValue next = backoff.next();
138138
logger.trace("Retry of bulk request scheduled in {} ms.", next.millis());
139-
Runnable command = scheduler.preserveContext(() -> this.execute(bulkRequestForRetry));
140-
retryCancellable = scheduler.schedule(command, next, ThreadPool.Names.SAME);
139+
retryCancellable = scheduler.schedule(() -> this.execute(bulkRequestForRetry), next, ThreadPool.Names.SAME);
141140
}
142141

143142
private BulkRequest createBulkRequestForRetry(BulkResponse bulkItemResponses) {

server/src/main/java/org/elasticsearch/index/reindex/RetryListener.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ public void onRejection(Exception e) {
7272
}
7373

7474
private void schedule(Runnable runnable, TimeValue delay) {
75-
// schedule does not preserve context so have to do this manually
76-
threadPool.schedule(threadPool.preserveContext(runnable), delay, ThreadPool.Names.SAME);
75+
threadPool.schedule(runnable, delay, ThreadPool.Names.SAME);
7776
}
7877
}

server/src/main/java/org/elasticsearch/threadpool/Scheduler.java

+1-11
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,7 @@ static boolean awaitTermination(final ScheduledThreadPoolExecutor scheduledThrea
8383
}
8484

8585
/**
86-
* Does nothing by default but can be used by subclasses to save the current thread context and wraps the command in a Runnable
87-
* that restores that context before running the command.
88-
*/
89-
default Runnable preserveContext(Runnable command) {
90-
return command;
91-
}
92-
93-
/**
94-
* Schedules a one-shot command to be run after a given delay. The command is not run in the context of the calling thread.
95-
* To preserve the context of the calling thread you may call {@link #preserveContext(Runnable)} on the runnable before passing
96-
* it to this method.
86+
* Schedules a one-shot command to be run after a given delay. The command is run in the context of the calling thread.
9787
* The command runs on scheduler thread. Do not run blocking calls on the scheduler thread. Subclasses may allow
9888
* to execute on a different executor, in which case blocking calls are allowed.
9989
*

server/src/main/java/org/elasticsearch/threadpool/ThreadPool.java

+2-8
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,7 @@ public ExecutorService executor(String name) {
323323
}
324324

325325
/**
326-
* Schedules a one-shot command to run after a given delay. The command is not run in the context of the calling thread. To preserve the
327-
* context of the calling thread you may call <code>threadPool.getThreadContext().preserveContext</code> on the runnable before passing
328-
* it to this method.
326+
* Schedules a one-shot command to run after a given delay. The command is run in the context of the calling thread.
329327
*
330328
* @param command the command to run
331329
* @param delay delay before the task executes
@@ -339,6 +337,7 @@ public ExecutorService executor(String name) {
339337
*/
340338
@Override
341339
public ScheduledCancellable schedule(Runnable command, TimeValue delay, String executor) {
340+
command = threadContext.preserveContext(command);
342341
if (!Names.SAME.equals(executor)) {
343342
command = new ThreadedRunnable(command, executor(executor));
344343
}
@@ -371,11 +370,6 @@ public Cancellable scheduleWithFixedDelay(Runnable command, TimeValue interval,
371370
command, executor), e));
372371
}
373372

374-
@Override
375-
public Runnable preserveContext(Runnable command) {
376-
return getThreadContext().preserveContext(command);
377-
}
378-
379373
protected final void stopCachedTimeThread() {
380374
cachedTimeThread.running = false;
381375
cachedTimeThread.interrupt();

server/src/test/java/org/elasticsearch/threadpool/ThreadPoolTests.java

+33
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@
2020
package org.elasticsearch.threadpool;
2121

2222
import org.elasticsearch.common.settings.Settings;
23+
import org.elasticsearch.common.unit.TimeValue;
2324
import org.elasticsearch.common.util.concurrent.EsExecutors;
2425
import org.elasticsearch.common.util.concurrent.FutureUtils;
2526
import org.elasticsearch.test.ESTestCase;
2627

28+
import java.util.concurrent.CountDownLatch;
2729
import java.util.concurrent.ExecutorService;
2830

2931
import static org.elasticsearch.threadpool.ThreadPool.ESTIMATED_TIME_INTERVAL_SETTING;
@@ -102,4 +104,35 @@ public void testAssertCurrentMethodIsNotCalledRecursively() {
102104
equalTo("org.elasticsearch.threadpool.ThreadPoolTests#factorialForked is called recursively"));
103105
terminate(threadPool);
104106
}
107+
108+
public void testInheritContextOnSchedule() throws InterruptedException {
109+
final CountDownLatch latch = new CountDownLatch(1);
110+
final CountDownLatch executed = new CountDownLatch(1);
111+
112+
TestThreadPool threadPool = new TestThreadPool("test");
113+
try {
114+
threadPool.getThreadContext().putHeader("foo", "bar");
115+
final Integer one = Integer.valueOf(1);
116+
threadPool.getThreadContext().putTransient("foo", one);
117+
threadPool.schedule(() -> {
118+
try {
119+
latch.await();
120+
} catch (InterruptedException e) {
121+
fail();
122+
}
123+
assertEquals(threadPool.getThreadContext().getHeader("foo"), "bar");
124+
assertSame(threadPool.getThreadContext().getTransient("foo"), one);
125+
assertNull(threadPool.getThreadContext().getHeader("bar"));
126+
assertNull(threadPool.getThreadContext().getTransient("bar"));
127+
executed.countDown();
128+
}, TimeValue.timeValueMillis(randomInt(100)), randomFrom(ThreadPool.Names.SAME, ThreadPool.Names.GENERIC));
129+
threadPool.getThreadContext().putTransient("bar", "boom");
130+
threadPool.getThreadContext().putHeader("bar", "boom");
131+
latch.countDown();
132+
executed.await();
133+
} finally {
134+
latch.countDown();
135+
terminate(threadPool);
136+
}
137+
}
105138
}

test/framework/src/main/java/org/elasticsearch/cluster/coordination/DeterministicTaskQueue.java

-5
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,6 @@ public Cancellable scheduleWithFixedDelay(Runnable command, TimeValue interval,
381381
return super.scheduleWithFixedDelay(command, interval, executor);
382382
}
383383

384-
@Override
385-
public Runnable preserveContext(Runnable command) {
386-
return command;
387-
}
388-
389384
@Override
390385
public void shutdown() {
391386
throw new UnsupportedOperationException();

x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,13 @@ private void internalAddCompletionListener(ActionListener<AsyncSearchResponse> l
217217

218218
final Cancellable cancellable;
219219
try {
220-
cancellable = threadPool.schedule(threadPool.preserveContext(() -> {
220+
cancellable = threadPool.schedule(() -> {
221221
if (hasRun.compareAndSet(false, true)) {
222222
// timeout occurred before completion
223223
removeCompletionListener(id);
224224
listener.onResponse(getResponse());
225225
}
226-
}), waitForCompletion, "generic");
226+
}, waitForCompletion, "generic");
227227
} catch (EsRejectedExecutionException exc) {
228228
listener.onFailure(exc);
229229
return;

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java

+12-22
Original file line numberDiff line numberDiff line change
@@ -786,10 +786,8 @@ private void indexInvalidation(Collection<String> tokenIds, SecurityIndexManager
786786
retryTokenDocIds.size(), tokenIds.size());
787787
final TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated,
788788
previouslyInvalidated, failedRequestResponses);
789-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
790-
.preserveContext(() -> indexInvalidation(retryTokenDocIds, tokensIndexManager, backoff,
791-
srcPrefix, incompleteResult, listener));
792-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
789+
client.threadPool().schedule(() -> indexInvalidation(retryTokenDocIds, tokensIndexManager, backoff,
790+
srcPrefix, incompleteResult, listener), backoff.next(), GENERIC);
793791
} else {
794792
if (retryTokenDocIds.isEmpty() == false) {
795793
logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries", retryTokenDocIds.size(),
@@ -809,10 +807,8 @@ private void indexInvalidation(Collection<String> tokenIds, SecurityIndexManager
809807
traceLog("invalidate tokens", cause);
810808
if (isShardNotAvailableException(cause) && backoff.hasNext()) {
811809
logger.debug("failed to invalidate tokens, retrying ");
812-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
813-
.preserveContext(() -> indexInvalidation(tokenIds, tokensIndexManager, backoff, srcPrefix,
814-
previousResult, listener));
815-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
810+
client.threadPool().schedule(() -> indexInvalidation(tokenIds, tokensIndexManager, backoff, srcPrefix,
811+
previousResult, listener), backoff.next(), GENERIC);
816812
} else {
817813
listener.onFailure(e);
818814
}
@@ -894,9 +890,8 @@ private void findTokenFromRefreshToken(String refreshToken, SecurityIndexManager
894890
if (backoff.hasNext()) {
895891
final TimeValue backofTimeValue = backoff.next();
896892
logger.debug("retrying after [{}] back off", backofTimeValue);
897-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
898-
.preserveContext(() -> findTokenFromRefreshToken(refreshToken, tokensIndexManager, backoff, listener));
899-
client.threadPool().schedule(retryWithContextRunnable, backofTimeValue, GENERIC);
893+
client.threadPool().schedule(() -> findTokenFromRefreshToken(refreshToken, tokensIndexManager, backoff, listener),
894+
backofTimeValue, GENERIC);
900895
} else {
901896
logger.warn("failed to find token from refresh token after all retries");
902897
onFailure.accept(ex);
@@ -1018,10 +1013,8 @@ private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Ob
10181013
} else if (backoff.hasNext()) {
10191014
logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying",
10201015
tokenDocId, updateResponse.getResult());
1021-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
1022-
.preserveContext(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm, clientAuth,
1023-
backoff, refreshRequested, listener));
1024-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
1016+
client.threadPool().schedule(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm,
1017+
clientAuth, backoff, refreshRequested, listener), backoff.next(), GENERIC);
10251018
} else {
10261019
logger.info("failed to update the original token document [{}] after all retries, the update result was [{}]. ",
10271020
tokenDocId, updateResponse.getResult());
@@ -1049,9 +1042,8 @@ public void onFailure(Exception e) {
10491042
if (isShardNotAvailableException(e)) {
10501043
if (backoff.hasNext()) {
10511044
logger.info("could not get token document [{}] for refresh, retrying", tokenDocId);
1052-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
1053-
.preserveContext(() -> getTokenDocAsync(tokenDocId, refreshedTokenIndex, this));
1054-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
1045+
client.threadPool().schedule(() -> getTokenDocAsync(tokenDocId, refreshedTokenIndex, this),
1046+
backoff.next(), GENERIC);
10551047
} else {
10561048
logger.warn("could not get token document [{}] for refresh after all retries", tokenDocId);
10571049
onFailure.accept(invalidGrantException("could not refresh the requested token"));
@@ -1064,10 +1056,8 @@ public void onFailure(Exception e) {
10641056
} else if (isShardNotAvailableException(e)) {
10651057
if (backoff.hasNext()) {
10661058
logger.debug("failed to update the original token document [{}], retrying", tokenDocId);
1067-
final Runnable retryWithContextRunnable = client.threadPool().getThreadContext()
1068-
.preserveContext(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm,
1069-
clientAuth, backoff, refreshRequested, listener));
1070-
client.threadPool().schedule(retryWithContextRunnable, backoff.next(), GENERIC);
1059+
client.threadPool().schedule(() -> innerRefresh(refreshToken, tokenDocId, source, seqNo, primaryTerm,
1060+
clientAuth, backoff, refreshRequested, listener), backoff.next(), GENERIC);
10711061
} else {
10721062
logger.warn("failed to update the original token document [{}], after all retries", tokenDocId);
10731063
onFailure.accept(invalidGrantException("could not refresh the requested token"));

0 commit comments

Comments
 (0)