Skip to content

Commit 263defd

Browse files
committed
Preserve thread context when connecting to remote cluster (#31574)
Establishing remote cluster connections uses a queue to coordinate multiple concurrent connect attempts. Connect attempts can be initiated by user-triggered searches as well as by system events (e.g. when nodes disconnect). Multiple such concurrent events can lead to the connectListener of one event being called under the thread context of another connect attempt. This can lead to the situation seen in #31462, where the connect listener is executed under the system context, which breaks when fetching the search shards from the remote cluster. Closes #31462
1 parent 3306420 commit 263defd

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.elasticsearch.action.admin.cluster.state.ClusterStateAction;
3535
import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest;
3636
import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
37+
import org.elasticsearch.action.support.ContextPreservingActionListener;
3738
import org.elasticsearch.cluster.ClusterName;
3839
import org.elasticsearch.cluster.node.DiscoveryNode;
3940
import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -362,9 +363,11 @@ void forceConnect() {
362363
private void connect(ActionListener<Void> connectListener, boolean forceRun) {
363364
final boolean runConnect;
364365
final Collection<ActionListener<Void>> toNotify;
366+
final ActionListener<Void> listener = connectListener == null ? null :
367+
ContextPreservingActionListener.wrapPreservingContext(connectListener, transportService.getThreadPool().getThreadContext());
365368
synchronized (queue) {
366-
if (connectListener != null && queue.offer(connectListener) == false) {
367-
connectListener.onFailure(new RejectedExecutionException("connect queue is full"));
369+
if (listener != null && queue.offer(listener) == false) {
370+
listener.onFailure(new RejectedExecutionException("connect queue is full"));
368371
return;
369372
}
370373
if (forceRun == false && queue.isEmpty()) {

server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java

+59
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.elasticsearch.common.transport.TransportAddress;
4949
import org.elasticsearch.common.unit.TimeValue;
5050
import org.elasticsearch.common.util.CancellableThreads;
51+
import org.elasticsearch.common.util.concurrent.ThreadContext;
5152
import org.elasticsearch.common.xcontent.XContentBuilder;
5253
import org.elasticsearch.common.xcontent.XContentFactory;
5354
import org.elasticsearch.core.internal.io.IOUtils;
@@ -561,6 +562,64 @@ public void testFetchShards() throws Exception {
561562
}
562563
}
563564

565+
/**
 * Checks that the thread context is preserved across concurrent calls to
 * {@code RemoteClusterConnection#fetchSearchShards} (issue #31462).
 *
 * Ten threads each put a distinct {@code "threadId"} header into the thread context and then
 * fetch search shards through the same connection. Each response callback asserts that the
 * header it sees is the one set by the thread that issued the request. After all threads have
 * finished, the test asserts that the connection reports no running connections.
 */
public void testFetchShardsThreadContextHeader() throws Exception {
566+
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
567+
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT);
568+
MockTransportService discoverableTransport = startTransport("discoverable_node", knownNodes, Version.CURRENT)) {
569+
DiscoveryNode seedNode = seedTransport.getLocalDiscoNode();
570+
knownNodes.add(seedTransport.getLocalDiscoNode());
571+
knownNodes.add(discoverableTransport.getLocalDiscoNode());
572+
Collections.shuffle(knownNodes, random());
573+
try (MockTransportService service = MockTransportService.createNewService(Settings.EMPTY, Version.CURRENT, threadPool, null)) {
574+
service.start();
575+
service.acceptIncomingRequests();
576+
// Only the seed node is passed to the connection.
// NOTE(review): presumably the second node is discovered via the seed - confirm.
List<DiscoveryNode> nodes = Collections.singletonList(seedNode);
577+
try (RemoteClusterConnection connection = new RemoteClusterConnection(Settings.EMPTY, "test-cluster",
578+
nodes, service, Integer.MAX_VALUE, n -> true)) {
579+
SearchRequest request = new SearchRequest("test-index");
580+
Thread[] threads = new Thread[10];
581+
for (int i = 0; i < threads.length; i++) {
582+
final String threadId = Integer.toString(i);
583+
threads[i] = new Thread(() -> {
584+
ThreadContext threadContext = seedTransport.threadPool.getThreadContext();
585+
// Tag this thread's context with a header value unique to the thread.
threadContext.putHeader("threadId", threadId);
586+
AtomicReference<ClusterSearchShardsResponse> reference = new AtomicReference<>();
587+
AtomicReference<Exception> failReference = new AtomicReference<>();
588+
final ClusterSearchShardsRequest searchShardsRequest = new ClusterSearchShardsRequest("test-index")
589+
.indicesOptions(request.indicesOptions()).local(true).preference(request.preference())
590+
.routing(request.routing());
591+
CountDownLatch responseLatch = new CountDownLatch(1);
592+
connection.fetchSearchShards(searchShardsRequest,
593+
new LatchedActionListener<>(ActionListener.wrap(
594+
resp -> {
595+
reference.set(resp);
596+
// The callback must see the header put by the thread that issued this request,
// not the context of some other concurrent connect attempt.
assertEquals(threadId, seedTransport.threadPool.getThreadContext().getHeader("threadId"));
597+
},
598+
failReference::set), responseLatch));
599+
try {
600+
// Block until the listener has been invoked.
// NOTE(review): assumes LatchedActionListener counts down responseLatch on completion - confirm.
responseLatch.await();
601+
} catch (InterruptedException e) {
602+
throw new RuntimeException(e);
603+
}
604+
assertNull(failReference.get());
605+
assertNotNull(reference.get());
606+
ClusterSearchShardsResponse clusterSearchShardsResponse = reference.get();
607+
assertEquals(knownNodes, Arrays.asList(clusterSearchShardsResponse.getNodes()));
608+
});
609+
}
610+
// Start the workers only after all of them have been constructed.
for (int i = 0; i < threads.length; i++) {
611+
threads[i].start();
612+
}
613+
614+
// Wait for every worker to finish before checking the connection state.
for (int i = 0; i < threads.length; i++) {
615+
threads[i].join();
616+
}
617+
assertTrue(connection.assertNoRunningConnections());
618+
}
619+
}
620+
}
621+
}
622+
564623
public void testFetchShardsSkipUnavailable() throws Exception {
565624
List<DiscoveryNode> knownNodes = new CopyOnWriteArrayList<>();
566625
try (MockTransportService seedTransport = startTransport("seed_node", knownNodes, Version.CURRENT)) {

0 commit comments

Comments
 (0)