Skip to content

Commit 13e2004

Browse files
authored
Executes incremental reduce in the search thread pool (#58461)
This change forks the execution of partial reduces in the coordinating node to the search thread pool. It also ensures that partial reduces are executed sequentially and asynchronously in order to limit the memory and cpu that a single search request can use but also to avoid blocking a network thread. If a partial reduce fails with an exception, the search request is cancelled and the reporting of the error is delayed to the start of the fetch phase (when the final reduce is performed). This ensures that we cleanup the in-flight search requests before returning an error to the user. Closes #53411 Relates #51857
1 parent 1aa070f commit 13e2004

24 files changed

+832
-497
lines changed

server/src/internalClusterTest/java/org/elasticsearch/action/RejectionActionIT.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import java.util.concurrent.CopyOnWriteArrayList;
3535
import java.util.concurrent.CountDownLatch;
3636

37+
import static org.hamcrest.Matchers.anyOf;
38+
import static org.hamcrest.Matchers.containsString;
3739
import static org.hamcrest.Matchers.equalTo;
3840

3941
@ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 2)
@@ -85,17 +87,17 @@ public void onFailure(Exception e) {
8587
if (response instanceof SearchResponse) {
8688
SearchResponse searchResponse = (SearchResponse) response;
8789
for (ShardSearchFailure failure : searchResponse.getShardFailures()) {
88-
assertTrue("got unexpected reason..." + failure.reason(),
89-
failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected"));
90+
assertThat(failure.reason().toLowerCase(Locale.ENGLISH),
91+
anyOf(containsString("cancelled"), containsString("rejected")));
9092
}
9193
} else {
9294
Exception t = (Exception) response;
9395
Throwable unwrap = ExceptionsHelper.unwrapCause(t);
9496
if (unwrap instanceof SearchPhaseExecutionException) {
9597
SearchPhaseExecutionException e = (SearchPhaseExecutionException) unwrap;
9698
for (ShardSearchFailure failure : e.shardFailures()) {
97-
assertTrue("got unexpected reason..." + failure.reason(),
98-
failure.reason().toLowerCase(Locale.ENGLISH).contains("rejected"));
99+
assertThat(failure.reason().toLowerCase(Locale.ENGLISH),
100+
anyOf(containsString("cancelled"), containsString("rejected")));
99101
}
100102
} else if ((unwrap instanceof EsRejectedExecutionException) == false) {
101103
throw new AssertionError("unexpected failure", (Throwable) response);

server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ public void testSearchProgressWithShardSort() throws Exception {
131131
testCase((NodeClient) client(), request, sortShards, false);
132132
}
133133

134-
private static void testCase(NodeClient client, SearchRequest request,
135-
List<SearchShard> expectedShards, boolean hasFetchPhase) throws InterruptedException {
134+
private void testCase(NodeClient client, SearchRequest request,
135+
List<SearchShard> expectedShards, boolean hasFetchPhase) throws InterruptedException {
136136
AtomicInteger numQueryResults = new AtomicInteger();
137137
AtomicInteger numQueryFailures = new AtomicInteger();
138138
AtomicInteger numFetchResults = new AtomicInteger();
@@ -204,7 +204,6 @@ public SearchTask createTask(long id, String type, String action, TaskId parentT
204204
}
205205
}, listener);
206206
latch.await();
207-
208207
assertThat(shardsListener.get(), equalTo(expectedShards));
209208
assertThat(numQueryResults.get(), equalTo(searchResponse.get().getSuccessfulShards()));
210209
assertThat(numQueryFailures.get(), equalTo(searchResponse.get().getFailedShards()));

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -470,12 +470,15 @@ public final void onShardFailure(final int shardIndex, @Nullable SearchShardTarg
470470
protected void onShardResult(Result result, SearchShardIterator shardIt) {
471471
assert result.getShardIndex() != -1 : "shard index is not set";
472472
assert result.getSearchShardTarget() != null : "search shard target must not be null";
473-
successfulOps.incrementAndGet();
474-
results.consumeResult(result);
475473
hasShardResponse.set(true);
476474
if (logger.isTraceEnabled()) {
477475
logger.trace("got first-phase result from {}", result != null ? result.getSearchShardTarget() : null);
478476
}
477+
results.consumeResult(result, () -> onShardResultConsumed(result, shardIt));
478+
}
479+
480+
private void onShardResultConsumed(Result result, SearchShardIterator shardIt) {
481+
successfulOps.incrementAndGet();
479482
// clean a previous error on this shard group (note, this code will be serialized on the same shardIndex value level
480483
// so its ok concurrency wise to miss potentially the shard failures being created because of another failure
481484
// in the #addShardFailure, because by definition, it will happen on *another* shardIndex

server/src/main/java/org/elasticsearch/action/search/ArraySearchPhaseResults.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ Stream<Result> getSuccessfulResults() {
3939
return results.asList().stream();
4040
}
4141

42-
void consumeResult(Result result) {
42+
@Override
43+
void consumeResult(Result result, Runnable next) {
4344
assert results.get(result.getShardIndex()) == null : "shardIndex: " + result.getShardIndex() + " is already set";
4445
results.set(result.getShardIndex(), result);
46+
next.run();
4547
}
4648

4749
boolean hasResult(int shardIndex) {

server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,12 @@ private static final class CanMatchSearchPhaseResults extends SearchPhaseResults
159159
}
160160

161161
@Override
162-
void consumeResult(CanMatchResponse result) {
163-
consumeResult(result.getShardIndex(), result.canMatch(), result.estimatedMinAndMax());
162+
void consumeResult(CanMatchResponse result, Runnable next) {
163+
try {
164+
consumeResult(result.getShardIndex(), result.canMatch(), result.estimatedMinAndMax());
165+
} finally {
166+
next.run();
167+
}
164168
}
165169

166170
@Override

server/src/main/java/org/elasticsearch/action/search/CountedCollector.java

+3-9
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,18 @@
2323
import org.elasticsearch.search.SearchPhaseResult;
2424
import org.elasticsearch.search.SearchShardTarget;
2525

26-
import java.util.function.Consumer;
27-
2826
/**
2927
* This is a simple base class to simplify fan out to shards and collect their results. Each results passed to
3028
* {@link #onResult(SearchPhaseResult)} will be set to the provided result array
3129
* where the given index is used to set the result on the array.
3230
*/
3331
final class CountedCollector<R extends SearchPhaseResult> {
34-
private final Consumer<R> resultConsumer;
32+
private final ArraySearchPhaseResults<R> resultConsumer;
3533
private final CountDown counter;
3634
private final Runnable onFinish;
3735
private final SearchPhaseContext context;
3836

39-
CountedCollector(Consumer<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
37+
CountedCollector(ArraySearchPhaseResults<R> resultConsumer, int expectedOps, Runnable onFinish, SearchPhaseContext context) {
4038
this.resultConsumer = resultConsumer;
4139
this.counter = new CountDown(expectedOps);
4240
this.onFinish = onFinish;
@@ -58,11 +56,7 @@ void countDown() {
5856
* Sets the result to the given array index and then runs {@link #countDown()}
5957
*/
6058
void onResult(R result) {
61-
try {
62-
resultConsumer.accept(result);
63-
} finally {
64-
countDown();
65-
}
59+
resultConsumer.consumeResult(result, this::countDown);
6660
}
6761

6862
/**

server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java

+5-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import java.io.IOException;
3232
import java.util.List;
33+
import java.util.function.Consumer;
3334
import java.util.function.Function;
3435

3536
/**
@@ -51,10 +52,11 @@ final class DfsQueryPhase extends SearchPhase {
5152
DfsQueryPhase(AtomicArray<DfsSearchResult> dfsSearchResults,
5253
SearchPhaseController searchPhaseController,
5354
Function<ArraySearchPhaseResults<SearchPhaseResult>, SearchPhase> nextPhaseFactory,
54-
SearchPhaseContext context) {
55+
SearchPhaseContext context, Consumer<Exception> onPartialMergeFailure) {
5556
super("dfs_query");
5657
this.progressListener = context.getTask().getProgressListener();
57-
this.queryResult = searchPhaseController.newSearchPhaseResults(progressListener, context.getRequest(), context.getNumShards());
58+
this.queryResult = searchPhaseController.newSearchPhaseResults(context, progressListener,
59+
context.getRequest(), context.getNumShards(), onPartialMergeFailure);
5860
this.searchPhaseController = searchPhaseController;
5961
this.dfsSearchResults = dfsSearchResults;
6062
this.nextPhaseFactory = nextPhaseFactory;
@@ -68,7 +70,7 @@ public void run() throws IOException {
6870
// to free up memory early
6971
final List<DfsSearchResult> resultList = dfsSearchResults.asList();
7072
final AggregatedDfs dfs = searchPhaseController.aggregateDfs(resultList);
71-
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(queryResult::consumeResult,
73+
final CountedCollector<SearchPhaseResult> counter = new CountedCollector<>(queryResult,
7274
resultList.size(),
7375
() -> context.executeNextPhase(this, nextPhaseFactory.apply(queryResult)), context);
7476
for (final DfsSearchResult dfsResult : resultList) {

server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

+5-6
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.search.query.QuerySearchResult;
3737
import org.elasticsearch.transport.Transport;
3838

39-
import java.io.IOException;
4039
import java.util.List;
4140
import java.util.function.BiFunction;
4241

@@ -45,7 +44,7 @@
4544
* Then it reaches out to all relevant shards to fetch the topN hits.
4645
*/
4746
final class FetchSearchPhase extends SearchPhase {
48-
private final AtomicArray<FetchSearchResult> fetchResults;
47+
private final ArraySearchPhaseResults<FetchSearchResult> fetchResults;
4948
private final SearchPhaseController searchPhaseController;
5049
private final AtomicArray<SearchPhaseResult> queryResults;
5150
private final BiFunction<InternalSearchResponse, String, SearchPhase> nextPhaseFactory;
@@ -73,7 +72,7 @@ final class FetchSearchPhase extends SearchPhase {
7372
throw new IllegalStateException("number of shards must match the length of the query results but doesn't:"
7473
+ context.getNumShards() + "!=" + resultConsumer.getNumShards());
7574
}
76-
this.fetchResults = new AtomicArray<>(resultConsumer.getNumShards());
75+
this.fetchResults = new ArraySearchPhaseResults<>(resultConsumer.getNumShards());
7776
this.searchPhaseController = searchPhaseController;
7877
this.queryResults = resultConsumer.getAtomicArray();
7978
this.nextPhaseFactory = nextPhaseFactory;
@@ -102,7 +101,7 @@ public void onFailure(Exception e) {
102101
});
103102
}
104103

105-
private void innerRun() throws IOException {
104+
private void innerRun() throws Exception {
106105
final int numShards = context.getNumShards();
107106
final boolean isScrollSearch = context.getRequest().scroll() != null;
108107
final List<SearchPhaseResult> phaseResults = queryResults.asList();
@@ -117,7 +116,7 @@ private void innerRun() throws IOException {
117116
final boolean queryAndFetchOptimization = queryResults.length() == 1;
118117
final Runnable finishPhase = ()
119118
-> moveToNextPhase(searchPhaseController, scrollId, reducedQueryPhase, queryAndFetchOptimization ?
120-
queryResults : fetchResults);
119+
queryResults : fetchResults.getAtomicArray());
121120
if (queryAndFetchOptimization) {
122121
assert phaseResults.isEmpty() || phaseResults.get(0).fetchResult() != null : "phaseResults empty [" + phaseResults.isEmpty()
123122
+ "], single result: " + phaseResults.get(0).fetchResult();
@@ -137,7 +136,7 @@ private void innerRun() throws IOException {
137136
final ScoreDoc[] lastEmittedDocPerShard = isScrollSearch ?
138137
searchPhaseController.getLastEmittedDocPerShard(reducedQueryPhase, numShards)
139138
: null;
140-
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(r -> fetchResults.set(r.getShardIndex(), r),
139+
final CountedCollector<FetchSearchResult> counter = new CountedCollector<>(fetchResults,
141140
docIdsToLoad.length, // we count down every shard in the result no matter if we got any results or not
142141
finishPhase, context);
143142
for (int i = 0; i < docIdsToLoad.length; i++) {

0 commit comments

Comments
 (0)