Skip to content

Commit 02a4ef3

Browse files
authored
Use system context for cluster state update tasks (#31241)
This commit makes cluster state update tasks always run under the system context; the original context is restored only when the listener that was provided with the task is called. A notable exception is the clusterStatePublished(...) callback, which still runs under the system context: it is defined at the executor level rather than the task level and is called only once for the combined batch of tasks, so it cannot be uniquely associated with a single task or thread context. Relates #30603
1 parent 1502812 commit 02a4ef3

File tree

19 files changed

+236
-92
lines changed

19 files changed

+236
-92
lines changed

server/src/main/java/org/elasticsearch/cluster/ClusterStateTaskExecutor.java

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ default boolean runOnlyOnMaster() {
4141
/**
4242
* Callback invoked after new cluster state is published. Note that
4343
* this method is not invoked if the cluster state was not updated.
44+
*
45+
* Note that this method will be executed using system context.
46+
*
4447
* @param clusterChangedEvent the change event for this cluster state change, containing
4548
* both old and new states
4649
*/

server/src/main/java/org/elasticsearch/cluster/ClusterStateUpdateTask.java

+6
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,12 @@ public String describeTasks(List<ClusterStateUpdateTask> tasks) {
6262
*/
6363
public abstract void onFailure(String source, Exception e);
6464

65+
@Override
66+
public final void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
67+
// final, empty implementation here as this method should only be defined in combination
68+
// with a batching executor as it will always be executed within the system context.
69+
}
70+
6571
/**
6672
* If the cluster state update task wasn't processed by the provided timeout, call
6773
* {@link ClusterStateTaskListener#onFailure(String, Exception)}. May return null to indicate no timeout is needed (default).

server/src/main/java/org/elasticsearch/cluster/service/MasterService.java

+22-13
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
4848
import org.elasticsearch.common.util.concurrent.FutureUtils;
4949
import org.elasticsearch.common.util.concurrent.PrioritizedEsThreadPoolExecutor;
50+
import org.elasticsearch.common.util.concurrent.ThreadContext;
5051
import org.elasticsearch.discovery.Discovery;
5152
import org.elasticsearch.threadpool.ThreadPool;
5253

@@ -59,6 +60,7 @@
5960
import java.util.concurrent.Future;
6061
import java.util.concurrent.TimeUnit;
6162
import java.util.function.BiConsumer;
63+
import java.util.function.Supplier;
6264
import java.util.stream.Collectors;
6365

6466
import static org.elasticsearch.cluster.service.ClusterService.CLUSTER_SERVICE_SLOW_TASK_LOGGING_THRESHOLD_SETTING;
@@ -426,26 +428,28 @@ public TimeValue getMaxTaskWaitTime() {
426428
return threadPoolExecutor.getMaxTaskWaitTime();
427429
}
428430

429-
private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener) {
431+
private SafeClusterStateTaskListener safe(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> contextSupplier) {
430432
if (listener instanceof AckedClusterStateTaskListener) {
431-
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, logger);
433+
return new SafeAckedClusterStateTaskListener((AckedClusterStateTaskListener) listener, contextSupplier, logger);
432434
} else {
433-
return new SafeClusterStateTaskListener(listener, logger);
435+
return new SafeClusterStateTaskListener(listener, contextSupplier, logger);
434436
}
435437
}
436438

437439
private static class SafeClusterStateTaskListener implements ClusterStateTaskListener {
438440
private final ClusterStateTaskListener listener;
441+
protected final Supplier<ThreadContext.StoredContext> context;
439442
private final Logger logger;
440443

441-
SafeClusterStateTaskListener(ClusterStateTaskListener listener, Logger logger) {
444+
SafeClusterStateTaskListener(ClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context, Logger logger) {
442445
this.listener = listener;
446+
this.context = context;
443447
this.logger = logger;
444448
}
445449

446450
@Override
447451
public void onFailure(String source, Exception e) {
448-
try {
452+
try (ThreadContext.StoredContext ignore = context.get()) {
449453
listener.onFailure(source, e);
450454
} catch (Exception inner) {
451455
inner.addSuppressed(e);
@@ -456,7 +460,7 @@ public void onFailure(String source, Exception e) {
456460

457461
@Override
458462
public void onNoLongerMaster(String source) {
459-
try {
463+
try (ThreadContext.StoredContext ignore = context.get()) {
460464
listener.onNoLongerMaster(source);
461465
} catch (Exception e) {
462466
logger.error(() -> new ParameterizedMessage(
@@ -466,7 +470,7 @@ public void onNoLongerMaster(String source) {
466470

467471
@Override
468472
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
469-
try {
473+
try (ThreadContext.StoredContext ignore = context.get()) {
470474
listener.clusterStateProcessed(source, oldState, newState);
471475
} catch (Exception e) {
472476
logger.error(() -> new ParameterizedMessage(
@@ -480,8 +484,9 @@ private static class SafeAckedClusterStateTaskListener extends SafeClusterStateT
480484
private final AckedClusterStateTaskListener listener;
481485
private final Logger logger;
482486

483-
SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Logger logger) {
484-
super(listener, logger);
487+
SafeAckedClusterStateTaskListener(AckedClusterStateTaskListener listener, Supplier<ThreadContext.StoredContext> context,
488+
Logger logger) {
489+
super(listener, context, logger);
485490
this.listener = listener;
486491
this.logger = logger;
487492
}
@@ -493,7 +498,7 @@ public boolean mustAck(DiscoveryNode discoveryNode) {
493498

494499
@Override
495500
public void onAllNodesAcked(@Nullable Exception e) {
496-
try {
501+
try (ThreadContext.StoredContext ignore = context.get()) {
497502
listener.onAllNodesAcked(e);
498503
} catch (Exception inner) {
499504
inner.addSuppressed(e);
@@ -503,7 +508,7 @@ public void onAllNodesAcked(@Nullable Exception e) {
503508

504509
@Override
505510
public void onAckTimeout() {
506-
try {
511+
try (ThreadContext.StoredContext ignore = context.get()) {
507512
listener.onAckTimeout();
508513
} catch (Exception e) {
509514
logger.error("exception thrown by listener while notifying on ack timeout", e);
@@ -724,9 +729,13 @@ public <T> void submitStateUpdateTasks(final String source,
724729
if (!lifecycle.started()) {
725730
return;
726731
}
727-
try {
732+
final ThreadContext threadContext = threadPool.getThreadContext();
733+
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
734+
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
735+
threadContext.markAsSystemContext();
736+
728737
List<Batcher.UpdateTask> safeTasks = tasks.entrySet().stream()
729-
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue()), executor))
738+
.map(e -> taskBatcher.new UpdateTask(config.priority(), source, e.getKey(), safe(e.getValue(), supplier), executor))
730739
.collect(Collectors.toList());
731740
taskBatcher.submitTasks(safeTasks, config.timeout());
732741
} catch (EsRejectedExecutionException e) {

server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java

-2
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,6 @@ public ClusterStateResponse newInstance() {
556556

557557
@Override
558558
public void handleResponse(ClusterStateResponse response) {
559-
assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
560559
try {
561560
if (remoteClusterName.get() == null) {
562561
assert response.getClusterName().value() != null;
@@ -597,7 +596,6 @@ public void handleResponse(ClusterStateResponse response) {
597596

598597
@Override
599598
public void handleException(TransportException exp) {
600-
assert transportService.getThreadPool().getThreadContext().isSystemContext() == false : "context is a system context";
601599
logger.warn(() -> new ParameterizedMessage("fetching nodes from external cluster {} failed", clusterAlias), exp);
602600
try {
603601
IOUtils.closeWhileHandlingException(connection);

server/src/test/java/org/elasticsearch/cluster/service/MasterServiceTests.java

+82
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@
3434
import org.elasticsearch.cluster.block.ClusterBlocks;
3535
import org.elasticsearch.cluster.node.DiscoveryNode;
3636
import org.elasticsearch.cluster.node.DiscoveryNodes;
37+
import org.elasticsearch.common.Nullable;
3738
import org.elasticsearch.common.Priority;
3839
import org.elasticsearch.common.collect.Tuple;
3940
import org.elasticsearch.common.logging.Loggers;
4041
import org.elasticsearch.common.settings.Settings;
4142
import org.elasticsearch.common.unit.TimeValue;
4243
import org.elasticsearch.common.util.concurrent.BaseFuture;
44+
import org.elasticsearch.common.util.concurrent.ThreadContext;
4345
import org.elasticsearch.discovery.Discovery;
4446
import org.elasticsearch.test.ESTestCase;
4547
import org.elasticsearch.test.MockLogAppender;
@@ -52,6 +54,7 @@
5254
import org.junit.BeforeClass;
5355

5456
import java.util.ArrayList;
57+
import java.util.Collections;
5558
import java.util.HashMap;
5659
import java.util.HashSet;
5760
import java.util.List;
@@ -168,6 +171,85 @@ public void onFailure(String source, Exception e) {
168171
nonMaster.close();
169172
}
170173

174+
public void testThreadContext() throws InterruptedException {
175+
final TimedMasterService master = createTimedMasterService(true);
176+
final CountDownLatch latch = new CountDownLatch(1);
177+
178+
try (ThreadContext.StoredContext ignored = threadPool.getThreadContext().stashContext()) {
179+
final Map<String, String> expectedHeaders = Collections.singletonMap("test", "test");
180+
threadPool.getThreadContext().putHeader(expectedHeaders);
181+
182+
final TimeValue ackTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
183+
final TimeValue masterTimeout = randomBoolean() ? TimeValue.ZERO : TimeValue.timeValueMillis(randomInt(10000));
184+
185+
master.submitStateUpdateTask("test", new AckedClusterStateUpdateTask<Void>(null, null) {
186+
@Override
187+
public ClusterState execute(ClusterState currentState) {
188+
assertTrue(threadPool.getThreadContext().isSystemContext());
189+
assertEquals(Collections.emptyMap(), threadPool.getThreadContext().getHeaders());
190+
191+
if (randomBoolean()) {
192+
return ClusterState.builder(currentState).build();
193+
} else if (randomBoolean()) {
194+
return currentState;
195+
} else {
196+
throw new IllegalArgumentException("mock failure");
197+
}
198+
}
199+
200+
@Override
201+
public void onFailure(String source, Exception e) {
202+
assertFalse(threadPool.getThreadContext().isSystemContext());
203+
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
204+
latch.countDown();
205+
}
206+
207+
@Override
208+
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
209+
assertFalse(threadPool.getThreadContext().isSystemContext());
210+
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
211+
latch.countDown();
212+
}
213+
214+
@Override
215+
protected Void newResponse(boolean acknowledged) {
216+
return null;
217+
}
218+
219+
public TimeValue ackTimeout() {
220+
return ackTimeout;
221+
}
222+
223+
@Override
224+
public TimeValue timeout() {
225+
return masterTimeout;
226+
}
227+
228+
@Override
229+
public void onAllNodesAcked(@Nullable Exception e) {
230+
assertFalse(threadPool.getThreadContext().isSystemContext());
231+
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
232+
latch.countDown();
233+
}
234+
235+
@Override
236+
public void onAckTimeout() {
237+
assertFalse(threadPool.getThreadContext().isSystemContext());
238+
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
239+
latch.countDown();
240+
}
241+
242+
});
243+
244+
assertFalse(threadPool.getThreadContext().isSystemContext());
245+
assertEquals(expectedHeaders, threadPool.getThreadContext().getHeaders());
246+
}
247+
248+
latch.await();
249+
250+
master.close();
251+
}
252+
171253
/*
172254
* test that a listener throwing an exception while handling a
173255
* notification does not prevent publication notification to the

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java

+6-7
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import org.elasticsearch.common.io.stream.StreamInput;
2121
import org.elasticsearch.common.io.stream.StreamOutput;
2222
import org.elasticsearch.common.io.stream.Writeable;
23-
import org.elasticsearch.common.util.concurrent.ThreadContext;
2423
import org.elasticsearch.common.xcontent.ObjectParser;
2524
import org.elasticsearch.common.xcontent.ToXContent;
2625
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -293,7 +292,7 @@ public Builder deleteJob(String jobId, PersistentTasksCustomMetaData tasks) {
293292
return this;
294293
}
295294

296-
public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
295+
public Builder putDatafeed(DatafeedConfig datafeedConfig, Map<String, String> headers) {
297296
if (datafeeds.containsKey(datafeedConfig.getId())) {
298297
throw new ResourceAlreadyExistsException("A datafeed with id [" + datafeedConfig.getId() + "] already exists");
299298
}
@@ -302,13 +301,13 @@ public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadCo
302301
Job job = jobs.get(jobId);
303302
DatafeedJobValidator.validate(datafeedConfig, job);
304303

305-
if (threadContext != null) {
304+
if (headers.isEmpty() == false) {
306305
// Adjust the request, adding security headers from the current thread context
307306
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig);
308-
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
307+
Map<String, String> securityHeaders = headers.entrySet().stream()
309308
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
310309
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
311-
builder.setHeaders(headers);
310+
builder.setHeaders(securityHeaders);
312311
datafeedConfig = builder.build();
313312
}
314313

@@ -328,15 +327,15 @@ private void checkJobIsAvailableForDatafeed(String jobId) {
328327
}
329328
}
330329

331-
public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks, ThreadContext threadContext) {
330+
public Builder updateDatafeed(DatafeedUpdate update, PersistentTasksCustomMetaData persistentTasks, Map<String, String> headers) {
332331
String datafeedId = update.getId();
333332
DatafeedConfig oldDatafeedConfig = datafeeds.get(datafeedId);
334333
if (oldDatafeedConfig == null) {
335334
throw ExceptionsHelper.missingDatafeedException(datafeedId);
336335
}
337336
checkDatafeedIsStopped(() -> Messages.getMessage(Messages.DATAFEED_CANNOT_UPDATE_IN_CURRENT_STATE, datafeedId,
338337
DatafeedState.STARTED), datafeedId, persistentTasks);
339-
DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig, threadContext);
338+
DatafeedConfig newDatafeedConfig = update.apply(oldDatafeedConfig, headers);
340339
if (newDatafeedConfig.getJobId().equals(oldDatafeedConfig.getJobId()) == false) {
341340
checkJobIsAvailableForDatafeed(newDatafeedConfig.getJobId());
342341
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdate.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
1414
import org.elasticsearch.common.unit.TimeValue;
15-
import org.elasticsearch.common.util.concurrent.ThreadContext;
1615
import org.elasticsearch.common.xcontent.ObjectParser;
1716
import org.elasticsearch.common.xcontent.ToXContentObject;
1817
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -264,7 +263,7 @@ ChunkingConfig getChunkingConfig() {
264263
* Applies the update to the given {@link DatafeedConfig}
265264
* @return a new {@link DatafeedConfig} that contains the update
266265
*/
267-
public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadContext) {
266+
public DatafeedConfig apply(DatafeedConfig datafeedConfig, Map<String, String> headers) {
268267
if (id.equals(datafeedConfig.getId()) == false) {
269268
throw new IllegalArgumentException("Cannot apply update to datafeedConfig with different id");
270269
}
@@ -301,12 +300,12 @@ public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadC
301300
builder.setChunkingConfig(chunkingConfig);
302301
}
303302

304-
if (threadContext != null) {
303+
if (headers.isEmpty() == false) {
305304
// Adjust the request, adding security headers from the current thread context
306-
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
305+
Map<String, String> securityHeaders = headers.entrySet().stream()
307306
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
308307
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
309-
builder.setHeaders(headers);
308+
builder.setHeaders(securityHeaders);
310309
}
311310

312311
return builder.build();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/datafeed/DatafeedUpdateTests.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ public void testApply_failBecauseTargetDatafeedHasDifferentId() {
114114

115115
public void testApply_givenEmptyUpdate() {
116116
DatafeedConfig datafeed = DatafeedConfigTests.createRandomizedDatafeedConfig("foo");
117-
DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed, null);
117+
DatafeedConfig updatedDatafeed = new DatafeedUpdate.Builder(datafeed.getId()).build().apply(datafeed, Collections.emptyMap());
118118
assertThat(datafeed, equalTo(updatedDatafeed));
119119
}
120120

@@ -125,7 +125,7 @@ public void testApply_givenPartialUpdate() {
125125

126126
DatafeedUpdate.Builder updated = new DatafeedUpdate.Builder(datafeed.getId());
127127
updated.setScrollSize(datafeed.getScrollSize() + 1);
128-
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
128+
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());
129129

130130
DatafeedConfig.Builder expectedDatafeed = new DatafeedConfig.Builder(datafeed);
131131
expectedDatafeed.setScrollSize(datafeed.getScrollSize() + 1);
@@ -149,7 +149,7 @@ public void testApply_givenFullUpdateNoAggregations() {
149149
update.setScrollSize(8000);
150150
update.setChunkingConfig(ChunkingConfig.newManual(TimeValue.timeValueHours(1)));
151151

152-
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
152+
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());
153153

154154
assertThat(updatedDatafeed.getJobId(), equalTo("bar"));
155155
assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_2")));
@@ -175,7 +175,7 @@ public void testApply_givenAggregations() {
175175
update.setAggregations(new AggregatorFactories.Builder().addAggregator(
176176
AggregationBuilders.histogram("a").interval(300000).field("time").subAggregation(maxTime)));
177177

178-
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, null);
178+
DatafeedConfig updatedDatafeed = update.build().apply(datafeed, Collections.emptyMap());
179179

180180
assertThat(updatedDatafeed.getIndices(), equalTo(Collections.singletonList("i_1")));
181181
assertThat(updatedDatafeed.getTypes(), equalTo(Collections.singletonList("t_1")));

0 commit comments

Comments
 (0)