Skip to content

Use a threadsafe map in SearchAsyncActionTests #34506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet;

public class SearchAsyncActionTests extends ESTestCase {

public void testSkipSearchShards() throws InterruptedException {
Expand Down Expand Up @@ -137,7 +141,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting sha
protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> results, SearchPhaseContext context) {
return new SearchPhase("test") {
@Override
public void run() throws IOException {
public void run() {
latch.countDown();
}
};
Expand Down Expand Up @@ -253,7 +257,6 @@ public void run() throws IOException {
public void testFanOutAndCollect() throws InterruptedException {
SearchRequest request = new SearchRequest();
request.setMaxConcurrentShardRequests(randomIntBetween(1, 100));
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<TestSearchResponse> response = new AtomicReference<>();
ActionListener<SearchResponse> responseListener = new ActionListener<SearchResponse>() {
@Override
Expand All @@ -270,7 +273,7 @@ public void onFailure(Exception e) {
DiscoveryNode primaryNode = new DiscoveryNode("node_1", new LocalTransportAddress("foo"), Version.CURRENT);
DiscoveryNode replicaNode = new DiscoveryNode("node_2", new LocalTransportAddress("bar"), Version.CURRENT);

Map<DiscoveryNode, Set<Long>> nodeToContextMap = new HashMap<>();
Map<DiscoveryNode, Set<Long>> nodeToContextMap = newConcurrentMap();
AtomicInteger contextIdGenerator = new AtomicInteger(0);
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
new OriginalIndices(new String[]{"idx"}, IndicesOptions.strictExpandOpenAndForbidClosed()),
Expand All @@ -289,7 +292,9 @@ public void sendFreeContext(Transport.Connection connection, long contextId, Ori
lookup.put(replicaNode.getId(), new MockConnection(replicaNode));
Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
final ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
AbstractSearchAsyncAction asyncAction =
final CountDownLatch latch = new CountDownLatch(1);
final AtomicBoolean latchTriggered = new AtomicBoolean();
AbstractSearchAsyncAction<TestSearchPhaseResult> asyncAction =
new AbstractSearchAsyncAction<TestSearchPhaseResult>(
"test",
logger,
Expand Down Expand Up @@ -317,7 +322,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting sha
Transport.Connection connection = getConnection(null, shard.currentNodeId());
TestSearchPhaseResult testSearchPhaseResult = new TestSearchPhaseResult(contextIdGenerator.incrementAndGet(),
connection.getNode());
Set<Long> ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> new HashSet<>());
Set<Long> ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet());
ids.add(testSearchPhaseResult.getRequestId());
if (randomBoolean()) {
listener.onResponse(testSearchPhaseResult);
Expand All @@ -330,13 +335,16 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, ShardRouting sha
protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> results, SearchPhaseContext context) {
return new SearchPhase("test") {
@Override
public void run() throws IOException {
public void run() {
for (int i = 0; i < results.getNumShards(); i++) {
TestSearchPhaseResult result = results.getAtomicArray().get(i);
assertEquals(result.node.getId(), result.getSearchShardTarget().getNodeId());
sendReleaseSearchContext(result.getRequestId(), new MockConnection(result.node), OriginalIndices.NONE);
}
responseListener.onResponse(response);
if (latchTriggered.compareAndSet(false, true) == false) {
throw new AssertionError("latch triggered twice");
}
latch.countDown();
}
};
Expand Down