diff --git a/core/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/core/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index 1c0b6c2840ae9..9e1b619019cb6 100644 --- a/core/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/core/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -53,9 +53,13 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; +import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet; + public class SearchAsyncActionTests extends ESTestCase { public void testSkipSearchShards() throws InterruptedException { @@ -137,7 +141,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting sha protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { @Override - public void run() throws IOException { + public void run() { latch.countDown(); } }; @@ -253,7 +257,6 @@ public void run() throws IOException { public void testFanOutAndCollect() throws InterruptedException { SearchRequest request = new SearchRequest(); request.setMaxConcurrentShardRequests(randomIntBetween(1, 100)); - CountDownLatch latch = new CountDownLatch(1); AtomicReference response = new AtomicReference<>(); ActionListener responseListener = new ActionListener() { @Override @@ -270,7 +273,7 @@ public void onFailure(Exception e) { DiscoveryNode primaryNode = new DiscoveryNode("node_1", new LocalTransportAddress("foo"), Version.CURRENT); DiscoveryNode replicaNode = new DiscoveryNode("node_2", new LocalTransportAddress("bar"), Version.CURRENT); - Map> nodeToContextMap = new HashMap<>(); + Map> nodeToContextMap = newConcurrentMap(); AtomicInteger contextIdGenerator = new AtomicInteger(0); GroupShardsIterator shardsIter = getShardsIter("idx", new OriginalIndices(new String[]{"idx"}, IndicesOptions.strictExpandOpenAndForbidClosed()), @@ -289,7 +292,9 @@ public void sendFreeContext(Transport.Connection connection, long contextId, Ori lookup.put(replicaNode.getId(), new MockConnection(replicaNode)); Map aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)); final ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors())); - AbstractSearchAsyncAction asyncAction = + final CountDownLatch latch = new CountDownLatch(1); + final AtomicBoolean latchTriggered = new AtomicBoolean(); + AbstractSearchAsyncAction asyncAction = new AbstractSearchAsyncAction( "test", logger, @@ -317,7 +322,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting sha Transport.Connection connection = getConnection(null, shard.currentNodeId()); TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(), connection.getNode()); - Set ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> new HashSet<>()); + Set ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet()); ids.add(testSearchPhaseResult.getRequestId()); if (randomBoolean()) { listener.onResponse(testSearchPhaseResult); @@ -330,13 +335,16 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting sha protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { @Override - public void run() throws IOException { + public void run() { for (int i = 0; i < results.getNumShards(); i++) { TestSearchPhaseResult result = results.getAtomicArray().get(i); assertEquals(result.node.getId(), result.getSearchShardTarget().getNodeId()); sendReleaseSearchContext(result.getRequestId(), new MockConnection(result.node), OriginalIndices.NONE); } responseListener.onResponse(response); + if (latchTriggered.compareAndSet(false, true) == false) { + throw new AssertionError("latch triggered twice"); + } latch.countDown(); } };