Skip to content

Commit 1600aee

Browse files
committed
QueryPhaseResultConsumer to call notifyPartialReduce
As part of elastic#60275 QueryPhaseResultConsumer ended up calling SearchProgressListener#onPartialReduce directly instead of notifyPartialReduce. That means we don't catch exceptions that may occur while executing the progress listener callback. This commit fix the call and adds a test for this scenario.
1 parent d15d796 commit 1600aee

File tree

2 files changed

+160
-2
lines changed

2 files changed

+160
-2
lines changed

server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ private MergeResult partialReduce(MergeTask task,
192192
SearchShardTarget target = result.getSearchShardTarget();
193193
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
194194
}
195-
progressListener.onPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
195+
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
196196
return new MergeResult(processedShards, newTopDocs, newAggs);
197197
}
198198

@@ -281,7 +281,7 @@ private void onMergeFailure(Exception exc) {
281281
if (task != null) {
282282
toCancel.add(task);
283283
}
284-
queue.stream().forEach(toCancel::add);
284+
toCancel.addAll(queue);
285285
queue.clear();
286286
mergeResult = null;
287287
toCancel.stream().forEach(MergeTask::cancel);
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.action.search;
21+
22+
import org.apache.lucene.search.ScoreDoc;
23+
import org.apache.lucene.search.TopDocs;
24+
import org.apache.lucene.search.TotalHits;
25+
import org.elasticsearch.action.OriginalIndices;
26+
import org.elasticsearch.common.io.stream.DelayableWriteable;
27+
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
28+
import org.elasticsearch.common.util.BigArrays;
29+
import org.elasticsearch.common.util.concurrent.EsExecutors;
30+
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
31+
import org.elasticsearch.index.shard.ShardId;
32+
import org.elasticsearch.search.DocValueFormat;
33+
import org.elasticsearch.search.SearchShardTarget;
34+
import org.elasticsearch.search.aggregations.InternalAggregation;
35+
import org.elasticsearch.search.aggregations.InternalAggregations;
36+
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
37+
import org.elasticsearch.search.query.QuerySearchResult;
38+
import org.elasticsearch.test.ESTestCase;
39+
import org.elasticsearch.threadpool.TestThreadPool;
40+
import org.elasticsearch.threadpool.ThreadPool;
41+
import org.junit.After;
42+
import org.junit.Before;
43+
44+
import java.util.ArrayList;
45+
import java.util.Collections;
46+
import java.util.List;
47+
import java.util.concurrent.CountDownLatch;
48+
import java.util.concurrent.TimeUnit;
49+
import java.util.concurrent.atomic.AtomicInteger;
50+
import java.util.concurrent.atomic.AtomicReference;
51+
52+
public class QueryPhaseResultConsumerTests extends ESTestCase {
53+
54+
private SearchPhaseController searchPhaseController;
55+
private ThreadPool threadPool;
56+
private EsThreadPoolExecutor executor;
57+
58+
@Before
59+
public void setup() {
60+
searchPhaseController = new SearchPhaseController(writableRegistry(),
61+
s -> new InternalAggregation.ReduceContextBuilder() {
62+
@Override
63+
public InternalAggregation.ReduceContext forPartialReduction() {
64+
return InternalAggregation.ReduceContext.forPartialReduction(
65+
BigArrays.NON_RECYCLING_INSTANCE, null, () -> PipelineAggregator.PipelineTree.EMPTY);
66+
}
67+
68+
public InternalAggregation.ReduceContext forFinalReduction() {
69+
return InternalAggregation.ReduceContext.forFinalReduction(
70+
BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineAggregator.PipelineTree.EMPTY);
71+
};
72+
});
73+
threadPool = new TestThreadPool(SearchPhaseControllerTests.class.getName());
74+
executor = EsExecutors.newFixed("test", 1, 10,
75+
EsExecutors.daemonThreadFactory("test"), threadPool.getThreadContext(), randomBoolean());
76+
77+
}
78+
79+
@After
80+
public void cleanup() {
81+
executor.shutdownNow();
82+
terminate(threadPool);
83+
}
84+
85+
public void testProgressListenerExceptionsAreCaught() throws Exception {
86+
87+
ThrowingSearchProgressListener searchProgressListener = new ThrowingSearchProgressListener();
88+
89+
List<SearchShard> searchShards = new ArrayList<>();
90+
for (int i = 0; i < 10; i++) {
91+
searchShards.add(new SearchShard(null, new ShardId("index", "uuid", i)));
92+
}
93+
searchProgressListener.notifyListShards(searchShards, Collections.emptyList(), SearchResponse.Clusters.EMPTY, false);
94+
95+
SearchRequest searchRequest = new SearchRequest("index");
96+
searchRequest.setBatchedReduceSize(2);
97+
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
98+
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, searchPhaseController,
99+
searchProgressListener, writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
100+
curr.addSuppressed(prev);
101+
return curr;
102+
}));
103+
104+
CountDownLatch partialReduceLatch = new CountDownLatch(10);
105+
106+
for (int i = 0; i < 10; i++) {
107+
SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i),
108+
null, OriginalIndices.NONE);
109+
QuerySearchResult querySearchResult = new QuerySearchResult();
110+
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
111+
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
112+
querySearchResult.setSearchShardTarget(searchShardTarget);
113+
querySearchResult.setShardIndex(i);
114+
queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown);
115+
}
116+
117+
assertEquals(10, searchProgressListener.onQueryResult.get());
118+
assertTrue(partialReduceLatch.await(10, TimeUnit.SECONDS));
119+
if (onPartialMergeFailure.get() != null) {
120+
throw onPartialMergeFailure.get();
121+
}
122+
assertEquals(8, searchProgressListener.onPartialReduce.get());
123+
124+
queryPhaseResultConsumer.reduce();
125+
assertEquals(1, searchProgressListener.onFinalReduce.get());
126+
}
127+
128+
private static class ThrowingSearchProgressListener extends SearchProgressListener {
129+
private final AtomicInteger onQueryResult = new AtomicInteger(0);
130+
private final AtomicInteger onPartialReduce = new AtomicInteger(0);
131+
private final AtomicInteger onFinalReduce = new AtomicInteger(0);
132+
133+
@Override
134+
protected void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, SearchResponse.Clusters clusters,
135+
boolean fetchPhase) {
136+
throw new UnsupportedOperationException();
137+
}
138+
139+
@Override
140+
protected void onQueryResult(int shardIndex) {
141+
onQueryResult.incrementAndGet();
142+
throw new UnsupportedOperationException();
143+
}
144+
145+
@Override
146+
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
147+
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
148+
onPartialReduce.incrementAndGet();
149+
throw new UnsupportedOperationException();
150+
}
151+
152+
@Override
153+
protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
154+
onFinalReduce.incrementAndGet();
155+
throw new UnsupportedOperationException();
156+
}
157+
}
158+
}

0 commit comments

Comments
 (0)