Skip to content

Commit 9eb9f92

Browse files
authored
Adds a consistent shard index to ShardSearchRequest (#65706)
* Adds a consistent shard index to ShardSearchRequest. This change ensures that the shard index used to tiebreak documents with identical sort values remains consistent between two requests that target the same shards. The index is now always computed from the natural order of the shards in the search request. This change also adds the consistent shard index to the ShardSearchRequest, which allows the slice builder to use this information to build more balanced slice queries. Relates #56828
1 parent 6c7b41f commit 9eb9f92

22 files changed

+181
-273
lines changed

server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.apache.logging.log4j.Logger;
2323
import org.apache.logging.log4j.message.ParameterizedMessage;
24+
import org.apache.lucene.util.CollectionUtil;
2425
import org.apache.lucene.util.SetOnce;
2526
import org.elasticsearch.ElasticsearchException;
2627
import org.elasticsearch.ExceptionsHelper;
@@ -48,10 +49,9 @@
4849

4950
import java.util.ArrayDeque;
5051
import java.util.ArrayList;
51-
import java.util.Collections;
52+
import java.util.HashMap;
5253
import java.util.List;
5354
import java.util.Map;
54-
import java.util.Set;
5555
import java.util.concurrent.ConcurrentHashMap;
5656
import java.util.concurrent.Executor;
5757
import java.util.concurrent.atomic.AtomicBoolean;
@@ -83,7 +83,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
8383
private final ClusterState clusterState;
8484
private final Map<String, AliasFilter> aliasFilter;
8585
private final Map<String, Float> concreteIndexBoosts;
86-
private final Map<String, Set<String>> indexRoutings;
8786
private final SetOnce<AtomicArray<ShardSearchFailure>> shardFailures = new SetOnce<>();
8887
private final Object shardFailuresMutex = new Object();
8988
private final AtomicBoolean hasShardResponse = new AtomicBoolean(false);
@@ -94,6 +93,7 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
9493

9594
protected final GroupShardsIterator<SearchShardIterator> toSkipShardsIts;
9695
protected final GroupShardsIterator<SearchShardIterator> shardsIts;
96+
private final Map<SearchShardIterator, Integer> shardItIndexMap;
9797
private final int expectedTotalOps;
9898
private final AtomicInteger totalOps = new AtomicInteger();
9999
private final int maxConcurrentRequestsPerNode;
@@ -106,7 +106,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
106106
AbstractSearchAsyncAction(String name, Logger logger, SearchTransportService searchTransportService,
107107
BiFunction<String, String, Transport.Connection> nodeIdToConnection,
108108
Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
109-
Map<String, Set<String>> indexRoutings,
110109
Executor executor, SearchRequest request,
111110
ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts,
112111
SearchTimeProvider timeProvider, ClusterState clusterState,
@@ -124,6 +123,17 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
124123
}
125124
this.toSkipShardsIts = new GroupShardsIterator<>(toSkipIterators);
126125
this.shardsIts = new GroupShardsIterator<>(iterators);
126+
this.shardItIndexMap = new HashMap<>();
127+
128+
// we compute the shard index based on the natural order of the shards
129+
// that participate in the search request. This means that this number is
130+
// consistent between two requests that target the same shards.
131+
List<SearchShardIterator> naturalOrder = new ArrayList<>(iterators);
132+
CollectionUtil.timSort(naturalOrder);
133+
for (int i = 0; i < naturalOrder.size(); i++) {
134+
shardItIndexMap.put(naturalOrder.get(i), i);
135+
}
136+
127137
// we need to add 1 for non active partition, since we count it in the total. This means for each shard in the iterator we sum up
128138
// its number of active shards but use 1 as the default if no replica of a shard is active at this point.
129139
// on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
@@ -143,7 +153,6 @@ abstract class AbstractSearchAsyncAction<Result extends SearchPhaseResult> exten
143153
this.clusterState = clusterState;
144154
this.concreteIndexBoosts = concreteIndexBoosts;
145155
this.aliasFilter = aliasFilter;
146-
this.indexRoutings = indexRoutings;
147156
this.results = resultConsumer;
148157
this.clusters = clusters;
149158
}
@@ -210,10 +219,13 @@ public final void run() {
210219
throw new SearchPhaseExecutionException(getName(), msg, null, ShardSearchFailure.EMPTY_ARRAY);
211220
}
212221
}
213-
for (int index = 0; index < shardsIts.size(); index++) {
214-
final SearchShardIterator shardRoutings = shardsIts.get(index);
222+
223+
for (int i = 0; i < shardsIts.size(); i++) {
224+
final SearchShardIterator shardRoutings = shardsIts.get(i);
215225
assert shardRoutings.skip() == false;
216-
performPhaseOnShard(index, shardRoutings, shardRoutings.nextOrNull());
226+
assert shardItIndexMap.containsKey(shardRoutings);
227+
int shardIndex = shardItIndexMap.get(shardRoutings);
228+
performPhaseOnShard(shardIndex, shardRoutings, shardRoutings.nextOrNull());
217229
}
218230
}
219231
}
@@ -651,15 +663,12 @@ public final void onFailure(Exception e) {
651663
}
652664

653665
@Override
654-
public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt) {
666+
public final ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt, int shardIndex) {
655667
AliasFilter filter = aliasFilter.get(shardIt.shardId().getIndex().getUUID());
656668
assert filter != null;
657669
float indexBoost = concreteIndexBoosts.getOrDefault(shardIt.shardId().getIndex().getUUID(), DEFAULT_INDEX_BOOST);
658-
String indexName = shardIt.shardId().getIndex().getName();
659-
final String[] routings = indexRoutings.getOrDefault(indexName, Collections.emptySet())
660-
.toArray(new String[0]);
661-
ShardSearchRequest shardRequest = new ShardSearchRequest(shardIt.getOriginalIndices(), request, shardIt.shardId(), getNumShards(),
662-
filter, indexBoost, timeProvider.getAbsoluteStartMillis(), shardIt.getClusterAlias(), routings,
670+
ShardSearchRequest shardRequest = new ShardSearchRequest(shardIt.getOriginalIndices(), request, shardIt.shardId(), shardIndex,
671+
getNumShards(), filter, indexBoost, timeProvider.getAbsoluteStartMillis(), shardIt.getClusterAlias(),
663672
shardIt.getSearchContextId(), shardIt.getSearchContextKeepAlive());
664673
// if we already received a search result we can inform the shard that it
665674
// can return a null response if the request rewrites to match none rather

server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import java.util.Comparator;
3737
import java.util.List;
3838
import java.util.Map;
39-
import java.util.Set;
4039
import java.util.concurrent.Executor;
4140
import java.util.function.BiFunction;
4241
import java.util.function.Function;
@@ -63,14 +62,13 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction<CanMa
6362
CanMatchPreFilterSearchPhase(Logger logger, SearchTransportService searchTransportService,
6463
BiFunction<String, String, Transport.Connection> nodeIdToConnection,
6564
Map<String, AliasFilter> aliasFilter, Map<String, Float> concreteIndexBoosts,
66-
Map<String, Set<String>> indexRoutings,
6765
Executor executor, SearchRequest request,
6866
ActionListener<SearchResponse> listener, GroupShardsIterator<SearchShardIterator> shardsIts,
6967
TransportSearchAction.SearchTimeProvider timeProvider, ClusterState clusterState,
7068
SearchTask task, Function<GroupShardsIterator<SearchShardIterator>, SearchPhase> phaseFactory,
7169
SearchResponse.Clusters clusters) {
7270
//We set max concurrent shard requests to the number of shards so no throttling happens for can_match requests
73-
super("can_match", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
71+
super("can_match", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts,
7472
executor, request, listener, shardsIts, timeProvider, clusterState, task,
7573
new CanMatchSearchPhaseResults(shardsIts.size()), shardsIts.size(), clusters);
7674
this.phaseFactory = phaseFactory;
@@ -86,7 +84,7 @@ public void addReleasable(Releasable releasable) {
8684
protected void executePhaseOnShard(SearchShardIterator shardIt, SearchShardTarget shard,
8785
SearchActionListener<CanMatchResponse> listener) {
8886
getSearchTransport().sendCanMatch(getConnection(shard.getClusterAlias(), shard.getNodeId()),
89-
buildShardSearchRequest(shardIt), getTask(), listener);
87+
buildShardSearchRequest(shardIt, listener.requestIndex), getTask(), listener);
9088
}
9189

9290
@Override
@@ -149,7 +147,7 @@ private static Comparator<Integer> shardComparator(GroupShardsIterator<SearchSha
149147
MinAndMax<?>[] minAndMaxes,
150148
SortOrder order) {
151149
final Comparator<Integer> comparator = Comparator.comparing(index -> minAndMaxes[index], MinAndMax.getComparator(order));
152-
return comparator.thenComparing(index -> shardsIts.get(index).shardId());
150+
return comparator.thenComparing(index -> shardsIts.get(index));
153151
}
154152

155153
private static final class CanMatchSearchPhaseResults extends SearchPhaseResults<CanMatchResponse> {

server/src/main/java/org/elasticsearch/action/search/FetchSearchPhase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ private void innerRun() throws Exception {
153153
ShardFetchSearchRequest fetchSearchRequest = createFetchRequest(queryResult.queryResult().getContextId(), i, entry,
154154
lastEmittedDocPerShard, searchShardTarget.getOriginalIndices(), queryResult.getShardSearchRequest(),
155155
queryResult.getRescoreDocIds());
156-
executeFetch(i, searchShardTarget, counter, fetchSearchRequest, queryResult.queryResult(),
156+
executeFetch(queryResult.getShardIndex(), searchShardTarget, counter, fetchSearchRequest, queryResult.queryResult(),
157157
connection);
158158
}
159159
}

server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232

3333
import java.util.List;
3434
import java.util.Map;
35-
import java.util.Set;
3635
import java.util.concurrent.Executor;
3736
import java.util.function.BiFunction;
3837

@@ -45,14 +44,14 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
4544
SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService,
4645
final BiFunction<String, String, Transport.Connection> nodeIdToConnection,
4746
final Map<String, AliasFilter> aliasFilter,
48-
final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings,
47+
final Map<String, Float> concreteIndexBoosts,
4948
final SearchPhaseController searchPhaseController, final Executor executor,
5049
final QueryPhaseResultConsumer queryPhaseResultConsumer,
5150
final SearchRequest request, final ActionListener<SearchResponse> listener,
5251
final GroupShardsIterator<SearchShardIterator> shardsIts,
5352
final TransportSearchAction.SearchTimeProvider timeProvider,
5453
final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters) {
55-
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
54+
super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts,
5655
executor, request, listener,
5756
shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()),
5857
request.getMaxConcurrentShardRequests(), clusters);
@@ -68,7 +67,7 @@ final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction
6867
protected void executePhaseOnShard(final SearchShardIterator shardIt, final SearchShardTarget shard,
6968
final SearchActionListener<DfsSearchResult> listener) {
7069
getSearchTransport().sendExecuteDfs(getConnection(shard.getClusterAlias(), shard.getNodeId()),
71-
buildShardSearchRequest(shardIt) , getTask(), listener);
70+
buildShardSearchRequest(shardIt, listener.requestIndex) , getTask(), listener);
7271
}
7372

7473
@Override

server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,12 @@ default void sendReleaseSearchContext(ShardSearchContextId contextId,
115115

116116
/**
117117
* Builds a request for the initial search phase.
118+
*
119+
* @param shardIt the target {@link SearchShardIterator}
120+
* @param shardIndex the index of the shard that is used in the coordinator node to
121+
* tiebreak results with identical sort values
118122
*/
119-
ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt);
123+
ShardSearchRequest buildShardSearchRequest(SearchShardIterator shardIt, int shardIndex);
120124

121125
/**
122126
* Processes the phase transition from one phase to another. This method handles all errors that happen during the initial run execution

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -428,21 +428,25 @@ ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> quer
428428
throw new IllegalStateException(errorMsg);
429429
}
430430
validateMergeSortValueFormats(queryResults);
431-
final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult();
432-
final boolean hasSuggest = firstResult.suggest() != null;
433-
final boolean hasProfileResults = firstResult.hasProfileResults();
431+
final boolean hasSuggest = queryResults.stream().anyMatch(res -> res.queryResult().suggest() != null);
432+
final boolean hasProfileResults = queryResults.stream().anyMatch(res -> res.queryResult().hasProfileResults());
434433

435434
// count the total (we use the query result provider here, since we might not get any hits (we scrolled past them))
436435
final Map<String, List<Suggestion>> groupedSuggestions = hasSuggest ? new HashMap<>() : Collections.emptyMap();
437436
final Map<String, ProfileShardResult> profileResults = hasProfileResults ? new HashMap<>(queryResults.size())
438437
: Collections.emptyMap();
439438
int from = 0;
440439
int size = 0;
440+
DocValueFormat[] sortValueFormats = null;
441441
for (SearchPhaseResult entry : queryResults) {
442442
QuerySearchResult result = entry.queryResult();
443443
from = result.from();
444444
// sorted queries can set the size to 0 if they have enough competitive hits.
445445
size = Math.max(result.size(), size);
446+
if (result.sortValueFormats() != null) {
447+
sortValueFormats = result.sortValueFormats();
448+
}
449+
446450
if (hasSuggest) {
447451
assert result.suggest() != null;
448452
for (Suggestion<? extends Suggestion.Entry<? extends Suggestion.Entry.Option>> suggestion : result.suggest()) {
@@ -477,7 +481,7 @@ ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> quer
477481
final TotalHits totalHits = topDocsStats.getTotalHits();
478482
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
479483
topDocsStats.timedOut, topDocsStats.terminatedEarly, reducedSuggest, aggregations, shardResults, sortedTopDocs,
480-
firstResult.sortValueFormats(), numReducePhases, size, from, false);
484+
sortValueFormats, numReducePhases, size, from, false);
481485
}
482486

483487
private static InternalAggregations reduceAggs(InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,

server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import org.elasticsearch.transport.Transport;
3434

3535
import java.util.Map;
36-
import java.util.Set;
3736
import java.util.concurrent.Executor;
3837
import java.util.function.BiFunction;
3938

@@ -52,14 +51,14 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
5251
SearchQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService,
5352
final BiFunction<String, String, Transport.Connection> nodeIdToConnection,
5453
final Map<String, AliasFilter> aliasFilter,
55-
final Map<String, Float> concreteIndexBoosts, final Map<String, Set<String>> indexRoutings,
54+
final Map<String, Float> concreteIndexBoosts,
5655
final SearchPhaseController searchPhaseController, final Executor executor,
5756
final QueryPhaseResultConsumer resultConsumer, final SearchRequest request,
5857
final ActionListener<SearchResponse> listener,
5958
final GroupShardsIterator<SearchShardIterator> shardsIts,
6059
final TransportSearchAction.SearchTimeProvider timeProvider,
6160
ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters) {
62-
super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings,
61+
super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts,
6362
executor, request, listener, shardsIts, timeProvider, clusterState, task,
6463
resultConsumer, request.getMaxConcurrentShardRequests(), clusters);
6564
this.topDocsSize = getTopDocsSize(request);
@@ -79,7 +78,7 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction<SearchPh
7978
protected void executePhaseOnShard(final SearchShardIterator shardIt,
8079
final SearchShardTarget shard,
8180
final SearchActionListener<SearchPhaseResult> listener) {
82-
ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt));
81+
ShardSearchRequest request = rewriteShardSearchRequest(super.buildShardSearchRequest(shardIt, listener.requestIndex));
8382
getSearchTransport().sendExecuteQuery(getConnection(shard.getClusterAlias(), shard.getNodeId()), request, getTask(), listener);
8483
}
8584

0 commit comments

Comments
 (0)