Skip to content

Commit 53e3677

Browse files
authored
QueryPhaseResultConsumer to call notifyPartialReduce (#62083)
As part of #60275 QueryPhaseResultConsumer ended up calling SearchProgressListener#onPartialReduce directly instead of notifyPartialReduce. That means we don't catch exceptions that may occur while executing the progress listener callback. This commit fixes the call and adds a test for this scenario.
1 parent 06c85d3 commit 53e3677

File tree

2 files changed

+157
-2
lines changed

2 files changed

+157
-2
lines changed

server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ private MergeResult partialReduce(MergeTask task,
192192
SearchShardTarget target = result.getSearchShardTarget();
193193
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
194194
}
195-
progressListener.onPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
195+
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
196196
return new MergeResult(processedShards, newTopDocs, newAggs);
197197
}
198198

@@ -281,7 +281,7 @@ private void onMergeFailure(Exception exc) {
281281
if (task != null) {
282282
toCancel.add(task);
283283
}
284-
queue.stream().forEach(toCancel::add);
284+
toCancel.addAll(queue);
285285
queue.clear();
286286
mergeResult = null;
287287
toCancel.stream().forEach(MergeTask::cancel);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.action.search;
21+
22+
import org.apache.lucene.search.ScoreDoc;
23+
import org.apache.lucene.search.TopDocs;
24+
import org.apache.lucene.search.TotalHits;
25+
import org.elasticsearch.action.OriginalIndices;
26+
import org.elasticsearch.common.io.stream.DelayableWriteable;
27+
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
28+
import org.elasticsearch.common.util.BigArrays;
29+
import org.elasticsearch.common.util.concurrent.EsExecutors;
30+
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
31+
import org.elasticsearch.index.shard.ShardId;
32+
import org.elasticsearch.search.DocValueFormat;
33+
import org.elasticsearch.search.SearchShardTarget;
34+
import org.elasticsearch.search.aggregations.InternalAggregation;
35+
import org.elasticsearch.search.aggregations.InternalAggregations;
36+
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
37+
import org.elasticsearch.search.query.QuerySearchResult;
38+
import org.elasticsearch.test.ESTestCase;
39+
import org.elasticsearch.threadpool.TestThreadPool;
40+
import org.elasticsearch.threadpool.ThreadPool;
41+
import org.junit.After;
42+
import org.junit.Before;
43+
44+
import java.util.ArrayList;
45+
import java.util.Collections;
46+
import java.util.List;
47+
import java.util.concurrent.CountDownLatch;
48+
import java.util.concurrent.TimeUnit;
49+
import java.util.concurrent.atomic.AtomicInteger;
50+
import java.util.concurrent.atomic.AtomicReference;
51+
52+
public class QueryPhaseResultConsumerTests extends ESTestCase {
53+
54+
private SearchPhaseController searchPhaseController;
55+
private ThreadPool threadPool;
56+
private EsThreadPoolExecutor executor;
57+
58+
@Before
59+
public void setup() {
60+
searchPhaseController = new SearchPhaseController(writableRegistry(),
61+
s -> new InternalAggregation.ReduceContextBuilder() {
62+
@Override
63+
public InternalAggregation.ReduceContext forPartialReduction() {
64+
return InternalAggregation.ReduceContext.forPartialReduction(
65+
BigArrays.NON_RECYCLING_INSTANCE, null, () -> PipelineAggregator.PipelineTree.EMPTY);
66+
}
67+
68+
public InternalAggregation.ReduceContext forFinalReduction() {
69+
return InternalAggregation.ReduceContext.forFinalReduction(
70+
BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineAggregator.PipelineTree.EMPTY);
71+
};
72+
});
73+
threadPool = new TestThreadPool(SearchPhaseControllerTests.class.getName());
74+
executor = EsExecutors.newFixed("test", 1, 10,
75+
EsExecutors.daemonThreadFactory("test"), threadPool.getThreadContext(), randomBoolean());
76+
}
77+
78+
@After
79+
public void cleanup() {
80+
executor.shutdownNow();
81+
terminate(threadPool);
82+
}
83+
84+
public void testProgressListenerExceptionsAreCaught() throws Exception {
85+
86+
ThrowingSearchProgressListener searchProgressListener = new ThrowingSearchProgressListener();
87+
88+
List<SearchShard> searchShards = new ArrayList<>();
89+
for (int i = 0; i < 10; i++) {
90+
searchShards.add(new SearchShard(null, new ShardId("index", "uuid", i)));
91+
}
92+
searchProgressListener.notifyListShards(searchShards, Collections.emptyList(), SearchResponse.Clusters.EMPTY, false);
93+
94+
SearchRequest searchRequest = new SearchRequest("index");
95+
searchRequest.setBatchedReduceSize(2);
96+
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
97+
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, searchPhaseController,
98+
searchProgressListener, writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
99+
curr.addSuppressed(prev);
100+
return curr;
101+
}));
102+
103+
CountDownLatch partialReduceLatch = new CountDownLatch(10);
104+
105+
for (int i = 0; i < 10; i++) {
106+
SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i),
107+
null, OriginalIndices.NONE);
108+
QuerySearchResult querySearchResult = new QuerySearchResult();
109+
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
110+
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
111+
querySearchResult.setSearchShardTarget(searchShardTarget);
112+
querySearchResult.setShardIndex(i);
113+
queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown);
114+
}
115+
116+
assertEquals(10, searchProgressListener.onQueryResult.get());
117+
assertTrue(partialReduceLatch.await(10, TimeUnit.SECONDS));
118+
assertNull(onPartialMergeFailure.get());
119+
assertEquals(8, searchProgressListener.onPartialReduce.get());
120+
121+
queryPhaseResultConsumer.reduce();
122+
assertEquals(1, searchProgressListener.onFinalReduce.get());
123+
}
124+
125+
private static class ThrowingSearchProgressListener extends SearchProgressListener {
126+
private final AtomicInteger onQueryResult = new AtomicInteger(0);
127+
private final AtomicInteger onPartialReduce = new AtomicInteger(0);
128+
private final AtomicInteger onFinalReduce = new AtomicInteger(0);
129+
130+
@Override
131+
protected void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, SearchResponse.Clusters clusters,
132+
boolean fetchPhase) {
133+
throw new UnsupportedOperationException();
134+
}
135+
136+
@Override
137+
protected void onQueryResult(int shardIndex) {
138+
onQueryResult.incrementAndGet();
139+
throw new UnsupportedOperationException();
140+
}
141+
142+
@Override
143+
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
144+
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
145+
onPartialReduce.incrementAndGet();
146+
throw new UnsupportedOperationException();
147+
}
148+
149+
@Override
150+
protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
151+
onFinalReduce.incrementAndGet();
152+
throw new UnsupportedOperationException();
153+
}
154+
}
155+
}

0 commit comments

Comments
 (0)