Skip to content

Commit 637f081

Browse files
authored
Fix possible NPE on search phase failure (#57952)
When a search phase fails, we release the contexts of all successful shards. However, successful shards that rewrite the request to match none no longer create a search context (the number of the earlier change that introduced this behavior is missing from this text), so their context ID is `null`. This change ensures that we don't try to release a `null` context on these successful shards. Closes #57945
1 parent e990c0a commit 637f081

File tree

4 files changed

+122
-8
lines changed

4 files changed

+122
-8
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -559,13 +559,15 @@ public final void onPhaseFailure(SearchPhase phase, String msg, Throwable cause)
559559
*/
560560
private void raisePhaseFailure(SearchPhaseExecutionException exception) {
561561
results.getSuccessfulResults().forEach((entry) -> {
562-
try {
563-
SearchShardTarget searchShardTarget = entry.getSearchShardTarget();
564-
Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
565-
sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices());
566-
} catch (Exception inner) {
567-
inner.addSuppressed(exception);
568-
logger.trace("failed to release context", inner);
562+
if (entry.getContextId() != null) {
563+
try {
564+
SearchShardTarget searchShardTarget = entry.getSearchShardTarget();
565+
Transport.Connection connection = getConnection(searchShardTarget.getClusterAlias(), searchShardTarget.getNodeId());
566+
sendReleaseSearchContext(entry.getContextId(), connection, searchShardTarget.getOriginalIndices());
567+
} catch (Exception inner) {
568+
inner.addSuppressed(exception);
569+
logger.trace("failed to release context", inner);
570+
}
569571
}
570572
});
571573
listener.onFailure(exception);

server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,6 +58,7 @@
5858
import java.io.IOException;
5959
import java.util.HashMap;
6060
import java.util.Map;
61+
import java.util.Objects;
6162
import java.util.function.BiFunction;
6263

6364
/**
@@ -199,7 +200,7 @@ static class ScrollFreeContextRequest extends TransportRequest {
199200
private SearchContextId contextId;
200201

201202
ScrollFreeContextRequest(SearchContextId contextId) {
202-
this.contextId = contextId;
203+
this.contextId = Objects.requireNonNull(contextId);
203204
}
204205

205206
ScrollFreeContextRequest(StreamInput in) throws IOException {

server/src/main/java/org/elasticsearch/search/SearchPhaseResult.java

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.search;
2121

22+
import org.elasticsearch.common.Nullable;
2223
import org.elasticsearch.common.io.stream.StreamInput;
2324
import org.elasticsearch.common.io.stream.StreamOutput;
2425
import org.elasticsearch.search.fetch.FetchSearchResult;
@@ -52,7 +53,9 @@ protected SearchPhaseResult(StreamInput in) throws IOException {
5253

5354
/**
5455
* Returns the search context ID that is used to reference the search context on the executing node
56+
* or <code>null</code> if no context was created.
5557
*/
58+
@Nullable
5659
public SearchContextId getContextId() {
5760
return contextId;
5861
}

server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java

Lines changed: 108 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,7 @@
5959

6060
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap;
6161
import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentSet;
62+
import static org.hamcrest.Matchers.containsString;
6263
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
6364

6465
public class SearchAsyncActionTests extends ESTestCase {
@@ -376,6 +377,113 @@ protected void executeNext(Runnable runnable, Thread originalThread) {
376377
executor.shutdown();
377378
}
378379

380+
public void testFanOutAndFail() throws InterruptedException {
381+
SearchRequest request = new SearchRequest();
382+
request.allowPartialSearchResults(true);
383+
request.setMaxConcurrentShardRequests(randomIntBetween(1, 100));
384+
CountDownLatch latch = new CountDownLatch(1);
385+
AtomicReference<Exception> failure = new AtomicReference<>();
386+
ActionListener<SearchResponse> responseListener = ActionListener.wrap(
387+
searchResponse -> { throw new AssertionError("unexpected response"); },
388+
exc -> {
389+
failure.set(exc);
390+
latch.countDown();
391+
});
392+
DiscoveryNode primaryNode = new DiscoveryNode("node_1", buildNewFakeTransportAddress(), Version.CURRENT);
393+
DiscoveryNode replicaNode = new DiscoveryNode("node_2", buildNewFakeTransportAddress(), Version.CURRENT);
394+
395+
Map<DiscoveryNode, Set<SearchContextId>> nodeToContextMap = newConcurrentMap();
396+
AtomicInteger contextIdGenerator = new AtomicInteger(0);
397+
int numShards = randomIntBetween(2, 10);
398+
GroupShardsIterator<SearchShardIterator> shardsIter = getShardsIter("idx",
399+
new OriginalIndices(new String[]{"idx"}, SearchRequest.DEFAULT_INDICES_OPTIONS),
400+
numShards, randomBoolean(), primaryNode, replicaNode);
401+
AtomicInteger numFreedContext = new AtomicInteger();
402+
SearchTransportService transportService = new SearchTransportService(null, null) {
403+
@Override
404+
public void sendFreeContext(Transport.Connection connection, SearchContextId contextId, OriginalIndices originalIndices) {
405+
assertNotNull(contextId);
406+
numFreedContext.incrementAndGet();
407+
assertTrue(nodeToContextMap.containsKey(connection.getNode()));
408+
assertTrue(nodeToContextMap.get(connection.getNode()).remove(contextId));
409+
}
410+
};
411+
Map<String, Transport.Connection> lookup = new HashMap<>();
412+
lookup.put(primaryNode.getId(), new MockConnection(primaryNode));
413+
lookup.put(replicaNode.getId(), new MockConnection(replicaNode));
414+
Map<String, AliasFilter> aliasFilters = Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY));
415+
ExecutorService executor = Executors.newFixedThreadPool(randomIntBetween(1, Runtime.getRuntime().availableProcessors()));
416+
AbstractSearchAsyncAction<TestSearchPhaseResult> asyncAction =
417+
new AbstractSearchAsyncAction<TestSearchPhaseResult>(
418+
"test",
419+
logger,
420+
transportService,
421+
(cluster, node) -> {
422+
assert cluster == null : "cluster was not null: " + cluster;
423+
return lookup.get(node); },
424+
aliasFilters,
425+
Collections.emptyMap(),
426+
Collections.emptyMap(),
427+
executor,
428+
request,
429+
responseListener,
430+
shardsIter,
431+
new TransportSearchAction.SearchTimeProvider(0, 0, () -> 0),
432+
ClusterState.EMPTY_STATE,
433+
null,
434+
new ArraySearchPhaseResults<>(shardsIter.size()),
435+
request.getMaxConcurrentShardRequests(),
436+
SearchResponse.Clusters.EMPTY) {
437+
TestSearchResponse response = new TestSearchResponse();
438+
439+
@Override
440+
protected void executePhaseOnShard(SearchShardIterator shardIt,
441+
ShardRouting shard,
442+
SearchActionListener<TestSearchPhaseResult> listener) {
443+
assertTrue("shard: " + shard.shardId() + " has been queried twice", response.queried.add(shard.shardId()));
444+
Transport.Connection connection = getConnection(null, shard.currentNodeId());
445+
final TestSearchPhaseResult testSearchPhaseResult;
446+
if (shard.shardId().id() == 0) {
447+
testSearchPhaseResult = new TestSearchPhaseResult(null, connection.getNode());
448+
} else {
449+
testSearchPhaseResult = new TestSearchPhaseResult(new SearchContextId(UUIDs.randomBase64UUID(),
450+
contextIdGenerator.incrementAndGet()), connection.getNode());
451+
Set<SearchContextId> ids = nodeToContextMap.computeIfAbsent(connection.getNode(), (n) -> newConcurrentSet());
452+
ids.add(testSearchPhaseResult.getContextId());
453+
}
454+
if (randomBoolean()) {
455+
listener.onResponse(testSearchPhaseResult);
456+
} else {
457+
new Thread(() -> listener.onResponse(testSearchPhaseResult)).start();
458+
}
459+
}
460+
461+
@Override
462+
protected SearchPhase getNextPhase(SearchPhaseResults<TestSearchPhaseResult> results,
463+
SearchPhaseContext context) {
464+
return new SearchPhase("test") {
465+
@Override
466+
public void run() {
467+
throw new RuntimeException("boom");
468+
}
469+
};
470+
}
471+
};
472+
asyncAction.start();
473+
latch.await();
474+
assertNotNull(failure.get());
475+
assertThat(failure.get().getCause().getMessage(), containsString("boom"));
476+
assertFalse(nodeToContextMap.isEmpty());
477+
assertTrue(nodeToContextMap.toString(), nodeToContextMap.containsKey(primaryNode) || nodeToContextMap.containsKey(replicaNode));
478+
assertEquals(shardsIter.size()-1, numFreedContext.get());
479+
if (nodeToContextMap.containsKey(primaryNode)) {
480+
assertTrue(nodeToContextMap.get(primaryNode).toString(), nodeToContextMap.get(primaryNode).isEmpty());
481+
} else {
482+
assertTrue(nodeToContextMap.get(replicaNode).toString(), nodeToContextMap.get(replicaNode).isEmpty());
483+
}
484+
executor.shutdown();
485+
}
486+
379487
public void testAllowPartialResults() throws InterruptedException {
380488
SearchRequest request = new SearchRequest();
381489
request.allowPartialSearchResults(false);

0 commit comments

Comments
 (0)