Skip to content

Commit 181bc80

Browse files
authored
Try to save memory on aggregations (backport of #53793) (#53996)
This delays deserializing the aggregation response tree until *right* before we merge the objects.
1 parent 80c24a0 commit 181bc80

File tree

8 files changed

+440
-65
lines changed

8 files changed

+440
-65
lines changed

server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java

+50-26
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,17 @@
1919

2020
package org.elasticsearch.action.search;
2121

22-
import com.carrotsearch.hppc.IntArrayList;
23-
import com.carrotsearch.hppc.ObjectObjectHashMap;
22+
import java.util.ArrayList;
23+
import java.util.Arrays;
24+
import java.util.Collection;
25+
import java.util.Collections;
26+
import java.util.HashMap;
27+
import java.util.List;
28+
import java.util.Map;
29+
import java.util.function.Function;
30+
import java.util.function.IntFunction;
31+
import java.util.function.Supplier;
32+
import java.util.stream.Collectors;
2433

2534
import org.apache.lucene.index.Term;
2635
import org.apache.lucene.search.CollectionStatistics;
@@ -58,16 +67,8 @@
5867
import org.elasticsearch.search.suggest.Suggest.Suggestion;
5968
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
6069

61-
import java.util.ArrayList;
62-
import java.util.Arrays;
63-
import java.util.Collection;
64-
import java.util.Collections;
65-
import java.util.HashMap;
66-
import java.util.List;
67-
import java.util.Map;
68-
import java.util.function.Function;
69-
import java.util.function.IntFunction;
70-
import java.util.stream.Collectors;
70+
import com.carrotsearch.hppc.IntArrayList;
71+
import com.carrotsearch.hppc.ObjectObjectHashMap;
7172

7273
public final class SearchPhaseController {
7374
private static final ScoreDoc[] EMPTY_DOCS = new ScoreDoc[0];
@@ -429,7 +430,7 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
429430
* @see QuerySearchResult#consumeProfileResult()
430431
*/
431432
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
432-
List<InternalAggregations> bufferedAggs, List<TopDocs> bufferedTopDocs,
433+
List<Supplier<InternalAggregations>> bufferedAggs, List<TopDocs> bufferedTopDocs,
433434
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
434435
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
435436
boolean performFinalReduce) {
@@ -453,7 +454,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
453454
final boolean hasSuggest = firstResult.suggest() != null;
454455
final boolean hasProfileResults = firstResult.hasProfileResults();
455456
final boolean consumeAggs;
456-
final List<InternalAggregations> aggregationsList;
457+
final List<Supplier<InternalAggregations>> aggregationsList;
457458
if (bufferedAggs != null) {
458459
consumeAggs = false;
459460
// we already have results from intermediate reduces and just need to perform the final reduce
@@ -492,7 +493,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
492493
}
493494
}
494495
if (consumeAggs) {
495-
aggregationsList.add((InternalAggregations) result.consumeAggs());
496+
aggregationsList.add(result.consumeAggs());
496497
}
497498
if (hasProfileResults) {
498499
String key = result.getSearchShardTarget().toString();
@@ -508,8 +509,7 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
508509
reducedSuggest = new Suggest(Suggest.reduce(groupedSuggestions));
509510
reducedCompletionSuggestions = reducedSuggest.filter(CompletionSuggestion.class);
510511
}
511-
final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(aggregationsList,
512-
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
512+
final InternalAggregations aggregations = reduceAggs(aggReduceContextBuilder, performFinalReduce, aggregationsList);
513513
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
514514
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size,
515515
reducedCompletionSuggestions);
@@ -519,6 +519,24 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
519519
firstResult.sortValueFormats(), numReducePhases, size, from, false);
520520
}
521521

522+
private InternalAggregations reduceAggs(
523+
InternalAggregation.ReduceContextBuilder aggReduceContextBuilder,
524+
boolean performFinalReduce,
525+
List<Supplier<InternalAggregations>> aggregationsList
526+
) {
527+
/*
528+
* Parse the aggregations, clearing the list as we go so bits backing
529+
* the DelayableWriteable can be collected immediately.
530+
*/
531+
List<InternalAggregations> toReduce = new ArrayList<>(aggregationsList.size());
532+
for (int i = 0; i < aggregationsList.size(); i++) {
533+
toReduce.add(aggregationsList.get(i).get());
534+
aggregationsList.set(i, null);
535+
}
536+
return aggregationsList.isEmpty() ? null : InternalAggregations.topLevelReduce(toReduce,
537+
performFinalReduce ? aggReduceContextBuilder.forFinalReduction() : aggReduceContextBuilder.forPartialReduction());
538+
}
539+
522540
/*
523541
* Returns the size of the requested top documents (from + size)
524542
*/
@@ -600,7 +618,7 @@ public InternalSearchResponse buildResponse(SearchHits hits) {
600618
*/
601619
static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<SearchPhaseResult> {
602620
private final SearchShardTarget[] processedShards;
603-
private final InternalAggregations[] aggsBuffer;
621+
private final Supplier<InternalAggregations>[] aggsBuffer;
604622
private final TopDocs[] topDocsBuffer;
605623
private final boolean hasAggs;
606624
private final boolean hasTopDocs;
@@ -642,7 +660,9 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
642660
this.progressListener = progressListener;
643661
this.processedShards = new SearchShardTarget[expectedResultSize];
644662
// no need to buffer anything if we have less expected results. in this case we don't consume any results ahead of time.
645-
this.aggsBuffer = new InternalAggregations[hasAggs ? bufferSize : 0];
663+
@SuppressWarnings("unchecked")
664+
Supplier<InternalAggregations>[] aggsBuffer = new Supplier[hasAggs ? bufferSize : 0];
665+
this.aggsBuffer = aggsBuffer;
646666
this.topDocsBuffer = new TopDocs[hasTopDocs ? bufferSize : 0];
647667
this.hasTopDocs = hasTopDocs;
648668
this.hasAggs = hasAggs;
@@ -665,10 +685,14 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
665685
if (querySearchResult.isNull() == false) {
666686
if (index == bufferSize) {
667687
if (hasAggs) {
668-
ReduceContext reduceContext = aggReduceContextBuilder.forPartialReduction();
669-
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), reduceContext);
670-
Arrays.fill(aggsBuffer, null);
671-
aggsBuffer[0] = reducedAggs;
688+
List<InternalAggregations> aggs = new ArrayList<>(aggsBuffer.length);
689+
for (int i = 0; i < aggsBuffer.length; i++) {
690+
aggs.add(aggsBuffer[i].get());
691+
aggsBuffer[i] = null; // null the buffer so it can be GCed now.
692+
}
693+
InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(
694+
aggs, aggReduceContextBuilder.forPartialReduction());
695+
aggsBuffer[0] = () -> reducedAggs;
672696
}
673697
if (hasTopDocs) {
674698
TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
@@ -681,12 +705,12 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
681705
index = 1;
682706
if (hasAggs || hasTopDocs) {
683707
progressListener.notifyPartialReduce(SearchProgressListener.buildSearchShards(processedShards),
684-
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0] : null, numReducePhases);
708+
topDocsStats.getTotalHits(), hasAggs ? aggsBuffer[0].get() : null, numReducePhases);
685709
}
686710
}
687711
final int i = index++;
688712
if (hasAggs) {
689-
aggsBuffer[i] = (InternalAggregations) querySearchResult.consumeAggs();
713+
aggsBuffer[i] = querySearchResult.consumeAggs();
690714
}
691715
if (hasTopDocs) {
692716
final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
@@ -698,7 +722,7 @@ private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
698722
processedShards[querySearchResult.getShardIndex()] = querySearchResult.getSearchShardTarget();
699723
}
700724

701-
private synchronized List<InternalAggregations> getRemainingAggs() {
725+
private synchronized List<Supplier<InternalAggregations>> getRemainingAggs() {
702726
return hasAggs ? Arrays.asList(aggsBuffer).subList(0, index) : null;
703727
}
704728

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.common.io.stream;
21+
22+
import org.elasticsearch.Version;
23+
import org.elasticsearch.common.bytes.BytesReference;
24+
25+
import java.io.IOException;
26+
import java.util.function.Supplier;
27+
28+
/**
29+
* A holder for {@link Writeable}s that can delay reading the underlying
30+
* {@linkplain Writeable} when it is read from a remote node.
31+
*/
32+
public abstract class DelayableWriteable<T extends Writeable> implements Supplier<T>, Writeable {
33+
/**
34+
* Build a {@linkplain DelayableWriteable} that wraps an existing object
35+
* but is serialized so that deserializing it can be delayed.
36+
*/
37+
public static <T extends Writeable> DelayableWriteable<T> referencing(T reference) {
38+
return new Referencing<>(reference);
39+
}
40+
/**
41+
* Build a {@linkplain DelayableWriteable} that copies a buffer from
42+
* the provided {@linkplain StreamInput} and deserializes the buffer
43+
* when {@link Supplier#get()} is called.
44+
*/
45+
public static <T extends Writeable> DelayableWriteable<T> delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
46+
return new Delayed<>(reader, in);
47+
}
48+
49+
private DelayableWriteable() {}
50+
51+
public abstract boolean isDelayed();
52+
53+
private static class Referencing<T extends Writeable> extends DelayableWriteable<T> {
54+
private T reference;
55+
56+
Referencing(T reference) {
57+
this.reference = reference;
58+
}
59+
60+
@Override
61+
public void writeTo(StreamOutput out) throws IOException {
62+
try (BytesStreamOutput buffer = new BytesStreamOutput()) {
63+
reference.writeTo(buffer);
64+
out.writeBytesReference(buffer.bytes());
65+
}
66+
}
67+
68+
@Override
69+
public T get() {
70+
return reference;
71+
}
72+
73+
@Override
74+
public boolean isDelayed() {
75+
return false;
76+
}
77+
}
78+
79+
private static class Delayed<T extends Writeable> extends DelayableWriteable<T> {
80+
private final Writeable.Reader<T> reader;
81+
private final Version remoteVersion;
82+
private final BytesReference serialized;
83+
private final NamedWriteableRegistry registry;
84+
85+
Delayed(Writeable.Reader<T> reader, StreamInput in) throws IOException {
86+
this.reader = reader;
87+
remoteVersion = in.getVersion();
88+
serialized = in.readBytesReference();
89+
registry = in.namedWriteableRegistry();
90+
}
91+
92+
@Override
93+
public void writeTo(StreamOutput out) throws IOException {
94+
if (out.getVersion() == remoteVersion) {
95+
/*
96+
* If the version *does* line up we can just copy the bytes
97+
* which is good because this is how shard request caching
98+
* works.
99+
*/
100+
out.writeBytesReference(serialized);
101+
} else {
102+
/*
103+
* If the version doesn't line up then we have to deserialize
104+
* into the Writeable and re-serialize it against the new
105+
* output stream so it can apply any backwards compatibility
106+
* differences in the wire protocol. This ain't efficient but
107+
* it should be quite rare.
108+
*/
109+
referencing(get()).writeTo(out);
110+
}
111+
}
112+
113+
@Override
114+
public T get() {
115+
try {
116+
try (StreamInput in = registry == null ?
117+
serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
118+
in.setVersion(remoteVersion);
119+
return reader.read(in);
120+
}
121+
} catch (IOException e) {
122+
throw new RuntimeException("unexpected error expanding aggregations", e);
123+
}
124+
}
125+
126+
@Override
127+
public boolean isDelayed() {
128+
return true;
129+
}
130+
}
131+
}

server/src/main/java/org/elasticsearch/common/io/stream/FilterStreamInput.java

+5
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,9 @@ public void setVersion(Version version) {
9494
protected void ensureCanReadBytes(int length) throws EOFException {
9595
delegate.ensureCanReadBytes(length);
9696
}
97+
98+
@Override
99+
public NamedWriteableRegistry namedWriteableRegistry() {
100+
return delegate.namedWriteableRegistry();
101+
}
97102
}

server/src/main/java/org/elasticsearch/common/io/stream/NamedWriteableAwareStreamInput.java

+5
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,9 @@ public <C extends NamedWriteable> C readNamedWriteable(@SuppressWarnings("unused
5252
+ "] than it was read from [" + name + "].";
5353
return c;
5454
}
55+
56+
@Override
57+
public NamedWriteableRegistry namedWriteableRegistry() {
58+
return namedWriteableRegistry;
59+
}
5560
}

server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

+8
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,14 @@ public <T extends Exception> T readException() throws IOException {
10971097
return null;
10981098
}
10991099

1100+
/**
1101+
* Get the registry of named writeables if this stream has one,
1102+
* {@code null} otherwise.
1103+
*/
1104+
public NamedWriteableRegistry namedWriteableRegistry() {
1105+
return null;
1106+
}
1107+
11001108
/**
11011109
* Reads a {@link NamedWriteable} from the current stream, by first reading its name and then looking for
11021110
* the corresponding entry in the registry by name, so that the proper object can be read and returned.

0 commit comments

Comments
 (0)