Skip to content

Commit b724619

Browse files
authored
Preserve thread context when connecting to remote cluster (#31574)
Establishing remote cluster connections uses a queue to coordinate multiple concurrent connect attempts. Connect attempts can be initiated by user-triggered searches as well as by system events (e.g. when nodes disconnect). Multiple such concurrent events can cause the connectListener of one event to be called under the thread context of another connect attempt. This can lead to the situation seen in #31462, where the connect listener is executed under the system context, which breaks fetching the search shards from the remote cluster. Closes #31462
1 parent 4fc833b commit b724619

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
3030
import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
3131
import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
32+
import org.elasticsearch.action.support.ContextPreservingActionListener;
3233
import org.elasticsearch.cluster.ClusterName;
3334
import org.elasticsearch.cluster.node.DiscoveryNode;
3435
import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -369,9 +370,11 @@ void forceConnect() {
369370
private void connect(ActionListener<Void> connectListener, boolean forceRun) {
370371
final boolean runConnect;
371372
final Collection<ActionListener<Void>> toNotify;
373+
final ActionListener<Void> listener = connectListener == null ? null :
374+
ContextPreservingActionListener.wrapPreservingContext(connectListener, transportService.getThreadPool().getThreadContext());
372375
synchronized (queue) {
373-
if (connectListener != null && queue.offer(connectListener) == false) {
374-
connectListener.onFailure(new RejectedExecutionException("connect queue is full"));
376+
if (listener != null && queue.offer(listener) == false) {
377+
listener.onFailure(new RejectedExecutionException("connect queue is full"));
375378
return;
376379
}
377380
if (forceRun == false && queue.isEmpty()) {

server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java

+59
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.elasticsearch.common.transport.TransportAddress;
4343
import org.elasticsearch.common.unit.TimeValue;
4444
import org.elasticsearch.common.util.CancellableThreads;
45+
import org.elasticsearch.common.util.concurrent.ThreadContext;
4546
import org.elasticsearch.common.xcontent.XContentBuilder;
4647
import org.elasticsearch.common.xcontent.XContentFactory;
4748
import org.elasticsearch.core.internal.io.IOUtils;
@@ -555,6 +556,64 @@ public void testFetchShards() throws Exception {
555556
}
556557
}
557558

559+
public void testFetchShardsThreadContextHeader() throws Exception {
560+
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
561+
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT);
562+
MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) {
563+
DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
564+
knownNodes.add(seedTransport.getLocalDiscoNode());
565+
knownNodes.add(discoverableTransport.getLocalDiscoNode());
566+
Collections.shuffle(knownNodes, random());
567+
try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
568+
service.start();
569+
service.acceptIncomingRequests();
570+
List<DiscoveryNode> nodes = Collections.singletonList(seedNode);
571+
try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
572+
nodes, service, Integer.MAX_VALUE, n -> true)) {
573+
SearchRequest request = new SearchRequest("test-index");
574+
Thread[] threads = new Thread[10];
575+
for (int i = 0; i < threads.length; i++) {
576+
final String threadId = Integer.toString(i);
577+
threads[i] = new Thread(() -> {
578+
ThreadContext threadContext = seedTransport.threadPool.getThreadContext();
579+
threadContext.putHeader("threadId", threadId);
580+
AtomicReference<ClusterSearchShardsResponse> reference = new AtomicReference<>();
581+
AtomicReference<Exception> failReference = new AtomicReference<>();
582+
final ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest("test-index")
583+
.indicesOptions(request.indicesOptions()).local(true).preference(request.preference())
584+
.routing(request.routing());
585+
CountDownLatch responseLatch = new CountDownLatch(1);
586+
connection.fetchSearchShards(searchShardsRequest,
587+
new LatchedActionListener<>(ActionListener.wrap(
588+
resp -> {
589+
reference.set(resp);
590+
assertEquals(threadId, seedTransport.threadPool.getThreadContext().getHeader("threadId"));
591+
},
592+
failReference::set), responseLatch));
593+
try {
594+
responseLatch.await();
595+
} catch (InterruptedException e) {
596+
throw new RuntimeException(e);
597+
}
598+
assertNull(failReference.get());
599+
assertNotNull(reference.get());
600+
ClusterSearchShardsResponse clusterSearchShardsResponse = reference.get();
601+
assertEquals(knownNodes, Arrays.asList(clusterSearchShardsResponse.getNodes()));
602+
});
603+
}
604+
for (int i = 0; i < threads.length; i++) {
605+
threads[i].start();
606+
}
607+
608+
for (int i = 0; i < threads.length; i++) {
609+
threads[i].join();
610+
}
611+
assertTrue(connection.assertNoRunningConnections());
612+
}
613+
}
614+
}
615+
}
616+
558617
public void testFetchShardsSkipUnavailable() throws Exception {
559618
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
560619
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) {

0 commit comments

Comments
 (0)