Skip to content

Commit e225641

Browse files
committed
QueryPhaseResultConsumer to call notifyPartialReduce (elastic#62083)
As part of elastic#60275 QueryPhaseResultConsumer ended up calling SearchProgressListener#onPartialReduce directly instead of notifyPartialReduce. That means we don't catch exceptions that may occur while executing the progress listener callback. This commit fixes the call and adds a test for this scenario.
1 parent 01e4972 commit e225641

File tree

2 files changed

+156
-2
lines changed

2 files changed

+156
-2
lines changed

server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ private MergeResult partialReduce(MergeTask task,
192192
SearchShardTarget target = result.getSearchShardTarget();
193193
processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId()));
194194
}
195-
progressListener.onPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
195+
progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases);
196196
return new MergeResult(processedShards, newTopDocs, newAggs);
197197
}
198198

@@ -281,7 +281,7 @@ private void onMergeFailure(Exception exc) {
281281
if (task != null) {
282282
toCancel.add(task);
283283
}
284-
queue.stream().forEach(toCancel::add);
284+
toCancel.addAll(queue);
285285
queue.clear();
286286
mergeResult = null;
287287
toCancel.stream().forEach(MergeTask::cancel);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.action.search;
21+
22+
import org.apache.lucene.search.ScoreDoc;
23+
import org.apache.lucene.search.TopDocs;
24+
import org.apache.lucene.search.TotalHits;
25+
import org.elasticsearch.action.OriginalIndices;
26+
import org.elasticsearch.common.io.stream.DelayableWriteable;
27+
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
28+
import org.elasticsearch.common.util.BigArrays;
29+
import org.elasticsearch.common.util.concurrent.EsExecutors;
30+
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
31+
import org.elasticsearch.index.shard.ShardId;
32+
import org.elasticsearch.search.DocValueFormat;
33+
import org.elasticsearch.search.SearchShardTarget;
34+
import org.elasticsearch.search.aggregations.InternalAggregation;
35+
import org.elasticsearch.search.aggregations.InternalAggregations;
36+
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
37+
import org.elasticsearch.search.query.QuerySearchResult;
38+
import org.elasticsearch.test.ESTestCase;
39+
import org.elasticsearch.threadpool.TestThreadPool;
40+
import org.elasticsearch.threadpool.ThreadPool;
41+
import org.junit.After;
42+
import org.junit.Before;
43+
44+
import java.util.ArrayList;
45+
import java.util.Collections;
46+
import java.util.List;
47+
import java.util.concurrent.CountDownLatch;
48+
import java.util.concurrent.TimeUnit;
49+
import java.util.concurrent.atomic.AtomicInteger;
50+
import java.util.concurrent.atomic.AtomicReference;
51+
52+
public class QueryPhaseResultConsumerTests extends ESTestCase {
53+
54+
private SearchPhaseController searchPhaseController;
55+
private ThreadPool threadPool;
56+
private EsThreadPoolExecutor executor;
57+
58+
@Before
59+
public void setup() {
60+
searchPhaseController = new SearchPhaseController(writableRegistry(),
61+
s -> new InternalAggregation.ReduceContextBuilder() {
62+
@Override
63+
public InternalAggregation.ReduceContext forPartialReduction() {
64+
return InternalAggregation.ReduceContext.forPartialReduction(
65+
BigArrays.NON_RECYCLING_INSTANCE, null, () -> PipelineAggregator.PipelineTree.EMPTY);
66+
}
67+
68+
public InternalAggregation.ReduceContext forFinalReduction() {
69+
return InternalAggregation.ReduceContext.forFinalReduction(
70+
BigArrays.NON_RECYCLING_INSTANCE, null, b -> {}, PipelineAggregator.PipelineTree.EMPTY);
71+
};
72+
});
73+
threadPool = new TestThreadPool(SearchPhaseControllerTests.class.getName());
74+
executor = EsExecutors.newFixed("test", 1, 10, EsExecutors.daemonThreadFactory("test"), threadPool.getThreadContext());
75+
}
76+
77+
@After
78+
public void cleanup() {
79+
executor.shutdownNow();
80+
terminate(threadPool);
81+
}
82+
83+
public void testProgressListenerExceptionsAreCaught() throws Exception {
84+
85+
ThrowingSearchProgressListener searchProgressListener = new ThrowingSearchProgressListener();
86+
87+
List<SearchShard> searchShards = new ArrayList<>();
88+
for (int i = 0; i < 10; i++) {
89+
searchShards.add(new SearchShard(null, new ShardId("index", "uuid", i)));
90+
}
91+
searchProgressListener.notifyListShards(searchShards, Collections.emptyList(), SearchResponse.Clusters.EMPTY, false);
92+
93+
SearchRequest searchRequest = new SearchRequest("index");
94+
searchRequest.setBatchedReduceSize(2);
95+
AtomicReference<Exception> onPartialMergeFailure = new AtomicReference<>();
96+
QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, searchPhaseController,
97+
searchProgressListener, writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> {
98+
curr.addSuppressed(prev);
99+
return curr;
100+
}));
101+
102+
CountDownLatch partialReduceLatch = new CountDownLatch(10);
103+
104+
for (int i = 0; i < 10; i++) {
105+
SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", i),
106+
null, OriginalIndices.NONE);
107+
QuerySearchResult querySearchResult = new QuerySearchResult();
108+
TopDocs topDocs = new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
109+
querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), new DocValueFormat[0]);
110+
querySearchResult.setSearchShardTarget(searchShardTarget);
111+
querySearchResult.setShardIndex(i);
112+
queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown);
113+
}
114+
115+
assertEquals(10, searchProgressListener.onQueryResult.get());
116+
assertTrue(partialReduceLatch.await(10, TimeUnit.SECONDS));
117+
assertNull(onPartialMergeFailure.get());
118+
assertEquals(8, searchProgressListener.onPartialReduce.get());
119+
120+
queryPhaseResultConsumer.reduce();
121+
assertEquals(1, searchProgressListener.onFinalReduce.get());
122+
}
123+
124+
private static class ThrowingSearchProgressListener extends SearchProgressListener {
125+
private final AtomicInteger onQueryResult = new AtomicInteger(0);
126+
private final AtomicInteger onPartialReduce = new AtomicInteger(0);
127+
private final AtomicInteger onFinalReduce = new AtomicInteger(0);
128+
129+
@Override
130+
protected void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, SearchResponse.Clusters clusters,
131+
boolean fetchPhase) {
132+
throw new UnsupportedOperationException();
133+
}
134+
135+
@Override
136+
protected void onQueryResult(int shardIndex) {
137+
onQueryResult.incrementAndGet();
138+
throw new UnsupportedOperationException();
139+
}
140+
141+
@Override
142+
protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
143+
DelayableWriteable.Serialized<InternalAggregations> aggs, int reducePhase) {
144+
onPartialReduce.incrementAndGet();
145+
throw new UnsupportedOperationException();
146+
}
147+
148+
@Override
149+
protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
150+
onFinalReduce.incrementAndGet();
151+
throw new UnsupportedOperationException();
152+
}
153+
}
154+
}

0 commit comments

Comments
 (0)