From 52613ff4f6c2ef662c61ed198c6981f23b21656f Mon Sep 17 00:00:00 2001
From: Jim Ferenczi
Date: Thu, 24 Sep 2020 14:02:49 +0200
Subject: [PATCH 1/2] Request-level circuit breaker support on coordinating
 nodes (#62223)

This commit allows the coordinating node to account for the memory used to
perform partial and final reduces of aggregations in the request circuit
breaker. The search coordinator adds the memory that it used to save and
reduce the results of shard aggregations in the request circuit breaker.
Before any partial or final reduce, the memory needed to reduce the
aggregations is estimated and a CircuitBreakingException is thrown if it
exceeds the maximum memory allowed in this breaker.
This size is estimated as roughly 1.5 times the size of the serialized
aggregations that need to be reduced. This estimation can be completely off
for some aggregations but it is corrected with the real size after the
reduce completes.
If the reduce is successful, we update the circuit breaker to remove the
size of the source aggregations and replace the estimation with the
serialized size of the newly reduced result.

As a follow-up, we could trigger partial reduces based on the memory
accounted in the circuit breaker instead of relying on a static number of
shard responses. A simpler follow-up that could be done in the meantime is
to [reduce the default batch reduce size](https://github.com/elastic/elasticsearch/issues/51857)
of blocking search requests to a saner number.

Closes #37182
---
 .../aggregations/TermsReduceBenchmark.java    | 230 ++++
 .../SearchProgressActionListenerIT.java       |   4 +-
 .../action/search/TransportSearchIT.java      | 493 ++++++++++++++++++
 .../search/AbstractSearchAsyncAction.java     |  20 +-
 .../search/CanMatchPreFilterSearchPhase.java  |   9 +-
 .../action/search/DfsQueryPhase.java          |  12 +-
 .../search/QueryPhaseResultConsumer.java      | 287 ++++++----
 .../SearchDfsQueryThenFetchAsyncAction.java   |  17 +-
 .../action/search/SearchPhaseContext.java     |   6 +
 .../action/search/SearchPhaseController.java  |   7 +-
 .../action/search/SearchProgressListener.java |   9 +-
 .../SearchQueryThenFetchAsyncAction.java      |  24 +-
 .../action/search/TransportSearchAction.java  |  15 +-
 .../aggregations/InternalAggregations.java    |  43 ++
 .../AbstractSearchAsyncActionTests.java       |   2 +-
 .../action/search/DfsQueryPhaseTests.java     |  27 +-
 .../action/search/FetchSearchPhaseTests.java  |  22 +-
 .../action/search/MockSearchPhaseContext.java |   6 +
 .../search/QueryPhaseResultConsumerTests.java |  10 +-
 .../action/search/SearchAsyncActionTests.java |   3 +-
 .../search/SearchPhaseControllerTests.java    | 139 ++---
 .../SearchQueryThenFetchAsyncActionTests.java |  13 +-
 .../TransportSearchActionSingleNodeTests.java | 177 -------
 .../InternalAggregationsTests.java            |  36 +-
 .../snapshots/SnapshotResiliencyTests.java    |   2 +-
 .../xpack/search/AsyncSearchTask.java         |  10 +-
 .../xpack/search/AsyncSearchTaskTests.java    |  50 +-
 27 files changed, 1200 insertions(+), 473 deletions(-)
 create mode 100644 benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java
 delete mode 100644 server/src/test/java/org/elasticsearch/action/search/TransportSearchActionSingleNodeTests.java

diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java
new file mode 100644
index 0000000000000..1b2e5672cfed4
--- /dev/null
+++ 
b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java @@ -0,0 +1,230 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.benchmark.search.aggregations; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.QueryPhaseResultConsumer; +import org.elasticsearch.action.search.SearchPhaseController; +import org.elasticsearch.action.search.SearchProgressListener; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.aggregations.BucketOrder; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.InternalAggregations; +import org.elasticsearch.search.aggregations.MultiBucketConsumerService; +import org.elasticsearch.search.aggregations.bucket.terms.StringTerms; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.query.QuerySearchResult; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static 
java.util.Collections.emptyList; + +@Warmup(iterations = 5) +@Measurement(iterations = 7) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(value = 1) +public class TermsReduceBenchmark { + private final SearchModule searchModule = new SearchModule(Settings.EMPTY, false, emptyList()); + private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(searchModule.getNamedWriteables()); + private final SearchPhaseController controller = new SearchPhaseController( + namedWriteableRegistry, + req -> new InternalAggregation.ReduceContextBuilder() { + @Override + public InternalAggregation.ReduceContext forPartialReduction() { + return InternalAggregation.ReduceContext.forPartialReduction(null, null, () -> PipelineAggregator.PipelineTree.EMPTY); + } + + @Override + public InternalAggregation.ReduceContext forFinalReduction() { + final MultiBucketConsumerService.MultiBucketConsumer bucketConsumer = new MultiBucketConsumerService.MultiBucketConsumer( + Integer.MAX_VALUE, + new NoneCircuitBreakerService().getBreaker(CircuitBreaker.REQUEST) + ); + return InternalAggregation.ReduceContext.forFinalReduction( + null, + null, + bucketConsumer, + PipelineAggregator.PipelineTree.EMPTY + ); + } + } + ); + + @State(Scope.Benchmark) + public static class TermsList extends AbstractList { + @Param({ "1600172297" }) + long seed; + + @Param({ "64", "128", "512" }) + int numShards; + + @Param({ "100" }) + int topNSize; + + @Param({ "1", "10", "100" }) + int cardinalityFactor; + + List aggsList; + + @Setup + public void setup() { + this.aggsList = new ArrayList<>(); + Random rand = new Random(seed); + int cardinality = cardinalityFactor * topNSize; + BytesRef[] dict = new BytesRef[cardinality]; + for (int i = 0; i < dict.length; i++) { + dict[i] = new BytesRef(Long.toString(rand.nextLong())); + } + for (int i = 0; i < numShards; i++) { + aggsList.add(InternalAggregations.from(Collections.singletonList(newTerms(rand, dict, true)))); + } + } + + private StringTerms newTerms(Random rand, BytesRef[] dict, boolean withNested) { + Set randomTerms = new HashSet<>(); + for (int i = 0; i < topNSize; i++) { + randomTerms.add(dict[rand.nextInt(dict.length)]); + } + List buckets = new ArrayList<>(); + for (BytesRef term : randomTerms) { + InternalAggregations subAggs; + if (withNested) { + subAggs = InternalAggregations.from(Collections.singletonList(newTerms(rand, dict, false))); + } else { + subAggs = InternalAggregations.EMPTY; + } + buckets.add(new StringTerms.Bucket(term, rand.nextInt(10000), subAggs, true, 0L, DocValueFormat.RAW)); + } + + Collections.sort(buckets, (a, b) -> a.compareKey(b)); + return new StringTerms( + "terms", + BucketOrder.key(true), + BucketOrder.count(false), + topNSize, + 1, + Collections.emptyMap(), + DocValueFormat.RAW, + numShards, + true, + 0, + buckets, + 0 + ); + } + + @Override + public InternalAggregations get(int index) { + return aggsList.get(index); + } + + @Override + public int size() { + return aggsList.size(); + } + } + + @Param({ "32", "512" }) + private int bufferSize; + + @Benchmark + public SearchPhaseController.ReducedQueryPhase reduceAggs(TermsList candidateList) throws Exception { + List shards = new ArrayList<>(); + for (int i = 0; i < candidateList.size(); i++) { + QuerySearchResult result = new QuerySearchResult(); + result.setShardIndex(i); + result.from(0); + result.size(0); + result.topDocs( + new TopDocsAndMaxScore( + new TopDocs(new TotalHits(1000, 
TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), new ScoreDoc[0]), + Float.NaN + ), + new DocValueFormat[] { DocValueFormat.RAW } + ); + result.aggregations(candidateList.get(i)); + result.setSearchShardTarget( + new SearchShardTarget("node", new ShardId(new Index("index", "index"), i), null, OriginalIndices.NONE) + ); + shards.add(result); + } + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder().size(0).aggregation(AggregationBuilders.terms("test"))); + request.setBatchedReduceSize(bufferSize); + ExecutorService executor = Executors.newFixedThreadPool(1); + QueryPhaseResultConsumer consumer = new QueryPhaseResultConsumer( + request, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + controller, + SearchProgressListener.NOOP, + namedWriteableRegistry, + shards.size(), + exc -> {} + ); + CountDownLatch latch = new CountDownLatch(shards.size()); + for (int i = 0; i < shards.size(); i++) { + consumer.consumeResult(shards.get(i), () -> latch.countDown()); + } + latch.await(); + SearchPhaseController.ReducedQueryPhase phase = consumer.reduce(); + executor.shutdownNow(); + return phase; + } +} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java index de4d4e66f884a..5fe61b9848a45 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/SearchProgressActionListenerIT.java @@ -23,7 +23,6 @@ import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.elasticsearch.client.Client; import org.elasticsearch.client.node.NodeClient; -import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.AggregationBuilders; @@ -174,8 +173,7 @@ public void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Except } @Override - public void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + public void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { numReduces.incrementAndGet(); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java index 8b707092141e0..72af251447103 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java @@ -19,22 +19,243 @@ package org.elasticsearch.action.search; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.ScoreMode; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.admin.cluster.node.stats.NodeStats; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest; +import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexResponse; +import 
org.elasticsearch.action.support.IndicesOptions; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.IndexSettings; +import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.index.query.RangeQueryBuilder; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorBase; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.CardinalityUpperBound; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.bucket.terms.LongTerms; +import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.InternalMax; +import org.elasticsearch.search.aggregations.support.ValueType; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.FetchSubPhase; +import org.elasticsearch.search.fetch.FetchSubPhaseProcessor; +import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.test.ESIntegTestCase; +import java.io.IOException; +import java.util.Collection; import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; public class TransportSearchIT extends ESIntegTestCase { + public static class TestPlugin extends Plugin implements SearchPlugin { + @Override + public List getAggregations() { + return Collections.singletonList( + new AggregationSpec(TestAggregationBuilder.NAME, TestAggregationBuilder::new, TestAggregationBuilder.PARSER) + .addResultReader(InternalMax::new) + ); + } + + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + /** + * Set up a fetch sub phase that throws an exception on indices whose name that start with "boom". 
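+ * The failure surfaces in the fetch phase, which lets testCircuitBreakerFetchFail below verify
+ * that the request breaker is fully released even when a search fails after the query phase
+ * has already buffered shard results.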
+ */ + return Collections.singletonList(fetchContext -> new FetchSubPhaseProcessor() { + @Override + public void setNextReader(LeafReaderContext readerContext) { + } + + @Override + public void process(FetchSubPhase.HitContext hitContext) { + if (fetchContext.getIndexName().startsWith("boom")) { + throw new RuntimeException("boom"); + } + } + }); + } + } + + @Override + protected Collection> nodePlugins() { + return Collections.singletonList(TestPlugin.class); + } + + public void testLocalClusterAlias() { + long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); + IndexRequest indexRequest = new IndexRequest("test"); + indexRequest.id("1"); + indexRequest.source("field", "value"); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, + "local", nowInMillis, randomBoolean()); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + SearchHit[] hits = searchResponse.getHits().getHits(); + assertEquals(1, hits.length); + SearchHit hit = hits[0]; + assertEquals("local", hit.getClusterAlias()); + assertEquals("test", hit.getIndex()); + assertEquals("1", hit.getId()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, + "", nowInMillis, randomBoolean()); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + SearchHit[] hits = searchResponse.getHits().getHits(); + assertEquals(1, hits.length); + SearchHit hit = hits[0]; + assertEquals("", hit.getClusterAlias()); + assertEquals("test", hit.getIndex()); + assertEquals("1", hit.getId()); + } + } + + public void testAbsoluteStartMillis() { + { + IndexRequest indexRequest = new IndexRequest("test-1970.01.01"); + indexRequest.id("1"); + indexRequest.source("date", "1970-01-01"); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + { + IndexRequest indexRequest = new IndexRequest("test-1982.01.01"); + indexRequest.id("1"); + indexRequest.source("date", "1982-01-01"); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + { + SearchRequest searchRequest = new SearchRequest(); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + } + { + SearchRequest searchRequest = new SearchRequest(""); + searchRequest.indicesOptions(IndicesOptions.fromOptions(true, true, true, true)); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(0, searchResponse.getTotalShards()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + Strings.EMPTY_ARRAY, "", 0, randomBoolean()); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new 
SearchRequest(), + Strings.EMPTY_ARRAY, "", 0, randomBoolean()); + searchRequest.indices(""); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), + Strings.EMPTY_ARRAY, "", 0, randomBoolean()); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date"); + rangeQuery.gte("1970-01-01"); + rangeQuery.lt("1982-01-01"); + sourceBuilder.query(rangeQuery); + searchRequest.source(sourceBuilder); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(1, searchResponse.getHits().getTotalHits().value); + assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); + } + } + + public void testFinalReduce() { + long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); + { + IndexRequest indexRequest = new IndexRequest("test"); + indexRequest.id("1"); + indexRequest.source("price", 10); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + { + IndexRequest indexRequest = new IndexRequest("test"); + indexRequest.id("2"); + indexRequest.source("price", 100); + IndexResponse indexResponse = client().index(indexRequest).actionGet(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + client().admin().indices().prepareRefresh("test").get(); + + SearchRequest originalRequest = new SearchRequest(); + SearchSourceBuilder source = new SearchSourceBuilder(); + source.size(0); + originalRequest.source(source); + TermsAggregationBuilder terms = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.NUMERIC); + terms.field("price"); + terms.size(1); + source.aggregation(terms); + + { + SearchRequest searchRequest = randomBoolean() ? 
originalRequest : SearchRequest.subSearchRequest(originalRequest, + Strings.EMPTY_ARRAY, "remote", nowInMillis, true); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + Aggregations aggregations = searchResponse.getAggregations(); + LongTerms longTerms = aggregations.get("terms"); + assertEquals(1, longTerms.getBuckets().size()); + } + { + SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest, + Strings.EMPTY_ARRAY, "remote", nowInMillis, false); + SearchResponse searchResponse = client().search(searchRequest).actionGet(); + assertEquals(2, searchResponse.getHits().getTotalHits().value); + Aggregations aggregations = searchResponse.getAggregations(); + LongTerms longTerms = aggregations.get("terms"); + assertEquals(2, longTerms.getBuckets().size()); + } + } public void testShardCountLimit() throws Exception { try { @@ -103,4 +324,276 @@ public void testSearchIdle() throws Exception { assertThat(resp.getHits().getTotalHits().value, equalTo(2L)); }); } + + public void testCircuitBreakerReduceFail() throws Exception { + int numShards = randomIntBetween(1, 10); + indexSomeDocs("test", numShards, numShards*3); + + { + final AtomicArray responses = new AtomicArray<>(10); + final CountDownLatch latch = new CountDownLatch(10); + for (int i = 0; i < 10; i++) { + int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); + SearchRequest request = client().prepareSearch("test") + .addAggregation(new TestAggregationBuilder("test")) + .setBatchedReduceSize(batchReduceSize) + .request(); + final int index = i; + client().search(request, new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + responses.set(index, true); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + responses.set(index, false); + latch.countDown(); + } + }); + } + latch.await(); + assertThat(responses.asList().size(), equalTo(10)); + for (boolean resp : responses.asList()) { + assertTrue(resp); + } + assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); + } + + try { + Settings settings = Settings.builder() + .put("indices.breaker.request.limit", "1b") + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setTransientSettings(settings)); + final Client client = client(); + assertBusy(() -> { + SearchPhaseExecutionException exc = expectThrows(SearchPhaseExecutionException.class, () -> client.prepareSearch("test") + .addAggregation(new TestAggregationBuilder("test")) + .get()); + assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("")); + }); + + final AtomicArray exceptions = new AtomicArray<>(10); + final CountDownLatch latch = new CountDownLatch(10); + for (int i = 0; i < 10; i++) { + int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); + SearchRequest request = client().prepareSearch("test") + .addAggregation(new TestAggregationBuilder("test")) + .setBatchedReduceSize(batchReduceSize) + .request(); + final int index = i; + client().search(request, new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exc) { + exceptions.set(index, exc); + latch.countDown(); + } + }); + } + latch.await(); + assertThat(exceptions.asList().size(), equalTo(10)); + for (Exception exc : exceptions.asList()) { + 
assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("")); + } + assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); + } finally { + Settings settings = Settings.builder() + .putNull("indices.breaker.request.limit") + .build(); + assertAcked(client().admin().cluster().prepareUpdateSettings().setTransientSettings(settings)); + } + } + + public void testCircuitBreakerFetchFail() throws Exception { + int numShards = randomIntBetween(1, 10); + int numDocs = numShards*10; + indexSomeDocs("boom", numShards, numDocs); + + final AtomicArray exceptions = new AtomicArray<>(10); + final CountDownLatch latch = new CountDownLatch(10); + for (int i = 0; i < 10; i++) { + int batchReduceSize = randomIntBetween(2, Math.max(numShards + 1, 3)); + SearchRequest request = client().prepareSearch("boom") + .setBatchedReduceSize(batchReduceSize) + .setAllowPartialSearchResults(false) + .request(); + final int index = i; + client().search(request, new ActionListener() { + @Override + public void onResponse(SearchResponse response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exc) { + exceptions.set(index, exc); + latch.countDown(); + } + }); + } + latch.await(); + assertThat(exceptions.asList().size(), equalTo(10)); + for (Exception exc : exceptions.asList()) { + assertThat(ExceptionsHelper.unwrapCause(exc).getCause().getMessage(), containsString("boom")); + } + assertBusy(() -> assertThat(requestBreakerUsed(), equalTo(0L))); + } + + private void indexSomeDocs(String indexName, int numberOfShards, int numberOfDocs) { + createIndex(indexName, Settings.builder().put("index.number_of_shards", numberOfShards).build()); + + for (int i = 0; i < numberOfDocs; i++) { + IndexResponse indexResponse = client().prepareIndex(indexName, "_doc") + .setSource("number", randomInt()) + .get(); + assertEquals(RestStatus.CREATED, indexResponse.status()); + } + client().admin().indices().prepareRefresh(indexName).get(); + } + + private long requestBreakerUsed() { + NodesStatsResponse stats = client().admin().cluster().prepareNodesStats() + .addMetric(NodesStatsRequest.Metric.BREAKER.metricName()) + .get(); + long estimated = 0; + for (NodeStats nodeStats : stats.getNodes()) { + estimated += nodeStats.getBreaker().getStats(CircuitBreaker.REQUEST).getEstimated(); + } + return estimated; + } + + /** + * A test aggregation that doesn't consume circuit breaker memory when running on shards. + * It is used to test the behavior of the circuit breaker when reducing multiple aggregations + * together (coordinator node). 
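+ * Each shard-level TestAggregator emits a constant {@link InternalMax}, so any request-breaker
+ * usage observed by these tests comes from the coordinating node's reduce phases alone.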
+ */ + private static class TestAggregationBuilder extends AbstractAggregationBuilder { + static final String NAME = "test"; + + private static final ObjectParser PARSER = + ObjectParser.fromBuilder(NAME, TestAggregationBuilder::new); + + TestAggregationBuilder(String name) { + super(name); + } + + TestAggregationBuilder(StreamInput input) throws IOException { + super(input); + } + + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + // noop + } + + @Override + protected AggregatorFactory doBuild(QueryShardContext queryShardContext, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException { + return new AggregatorFactory(name, queryShardContext, parent, subFactoriesBuilder, metadata) { + @Override + protected Aggregator createInternal(SearchContext searchContext, + Aggregator parent, + CardinalityUpperBound cardinality, + Map metadata) throws IOException { + return new TestAggregator(name, parent, searchContext); + } + }; + } + + @Override + protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } + + @Override + protected AggregationBuilder shallowCopy(AggregatorFactories.Builder factoriesBuilder, Map metadata) { + return new TestAggregationBuilder(name); + } + + @Override + public BucketCardinality bucketCardinality() { + return BucketCardinality.NONE; + } + + @Override + public String getType() { + return "test"; + } + } + + /** + * A test aggregator that extends {@link Aggregator} instead of {@link AggregatorBase} + * to avoid tripping the circuit breaker when executing on a shard. + */ + private static class TestAggregator extends Aggregator { + private final String name; + private final Aggregator parent; + private final SearchContext context; + + private TestAggregator(String name, Aggregator parent, SearchContext context) { + this.name = name; + this.parent = parent; + this.context = context; + } + + + @Override + public String name() { + return name; + } + + @Override + public SearchContext context() { + return context; + } + + @Override + public Aggregator parent() { + return parent; + } + + @Override + public Aggregator subAggregator(String name) { + return null; + } + + @Override + public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException { + return new InternalAggregation[] { + new InternalMax(name(), Double.NaN, DocValueFormat.RAW, Collections.emptyMap()) + }; + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalMax(name(), Double.NaN, DocValueFormat.RAW, Collections.emptyMap()); + } + + @Override + public void close() {} + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx) throws IOException { + throw new CollectionTerminatedException(); + } + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + + @Override + public void preCollection() throws IOException {} + + @Override + public void postCollection() throws IOException {} + } } diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 31c805e0b222f..0caca6976c02f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -33,6 +33,8 @@ import org.elasticsearch.cluster.ClusterState; import 
org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.index.shard.ShardId; @@ -77,7 +79,7 @@ abstract class AbstractSearchAsyncAction exten **/ private final BiFunction nodeIdToConnection; private final SearchTask task; - final SearchPhaseResults results; + protected final SearchPhaseResults results; private final ClusterState clusterState; private final Map aliasFilter; private final Map concreteIndexBoosts; @@ -98,6 +100,8 @@ abstract class AbstractSearchAsyncAction exten private final Map pendingExecutionsPerNode = new ConcurrentHashMap<>(); private final boolean throttleConcurrentRequests; + private final List releasables = new ArrayList<>(); + AbstractSearchAsyncAction(String name, Logger logger, SearchTransportService searchTransportService, BiFunction nodeIdToConnection, Map aliasFilter, Map concreteIndexBoosts, @@ -133,7 +137,7 @@ abstract class AbstractSearchAsyncAction exten this.executor = executor; this.request = request; this.task = task; - this.listener = listener; + this.listener = ActionListener.runAfter(listener, this::releaseContext); this.nodeIdToConnection = nodeIdToConnection; this.clusterState = clusterState; this.concreteIndexBoosts = concreteIndexBoosts; @@ -143,6 +147,15 @@ abstract class AbstractSearchAsyncAction exten this.clusters = clusters; } + @Override + public void addReleasable(Releasable releasable) { + releasables.add(releasable); + } + + public void releaseContext() { + Releasables.close(releasables); + } + /** * Builds how long it took to execute the search. 
*/ @@ -529,7 +542,7 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At ShardSearchFailure[] failures = buildShardFailures(); Boolean allowPartialResults = request.allowPartialSearchResults(); assert allowPartialResults != null : "SearchRequest missing setting for allowPartialSearchResults"; - if (request.pointInTimeBuilder() == null && allowPartialResults == false && failures.length > 0) { + if (allowPartialResults == false && failures.length > 0) { raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } else { final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); @@ -567,6 +580,7 @@ private void raisePhaseFailure(SearchPhaseExecutionException exception) { } }); } + Releasables.close(releasables); listener.onFailure(exception); } diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index db59c39559ed2..663cb861cc047 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -23,6 +23,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.routing.GroupShardsIterator; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.search.SearchService.CanMatchResponse; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -76,6 +77,11 @@ final class CanMatchPreFilterSearchPhase extends AbstractSearchAsyncAction listener) { @@ -84,8 +90,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, SearchShardTarge } @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, - SearchPhaseContext context) { + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return phaseFactory.apply(getIterator((CanMatchSearchPhaseResults) results, shardsIts)); } diff --git a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java index 980049e99afc4..e0fe285b730ec 100644 --- a/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/DfsQueryPhase.java @@ -29,7 +29,6 @@ import java.io.IOException; import java.util.List; -import java.util.function.Consumer; import java.util.function.Function; /** @@ -50,18 +49,21 @@ final class DfsQueryPhase extends SearchPhase { DfsQueryPhase(List searchResults, AggregatedDfs dfs, - SearchPhaseController searchPhaseController, + QueryPhaseResultConsumer queryResult, Function, SearchPhase> nextPhaseFactory, - SearchPhaseContext context, Consumer onPartialMergeFailure) { + SearchPhaseContext context) { super("dfs_query"); this.progressListener = context.getTask().getProgressListener(); - this.queryResult = searchPhaseController.newSearchPhaseResults(context, progressListener, - context.getRequest(), context.getNumShards(), onPartialMergeFailure); + this.queryResult = queryResult; this.searchResults = searchResults; this.dfs = dfs; this.nextPhaseFactory = nextPhaseFactory; this.context = context; this.searchTransportService = context.getSearchTransport(); + + // register the release of the query consumer to free up the circuit breaker memory + // at 
the end of the search + context.addReleasable(queryResult); } @Override diff --git a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java index 860cf46645f2b..24531f90b6fd0 100644 --- a/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java +++ b/server/src/main/java/org/elasticsearch/action/search/QueryPhaseResultConsumer.java @@ -23,8 +23,11 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TopDocs; import org.elasticsearch.action.search.SearchPhaseController.TopDocsStats; -import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.search.SearchPhaseResult; @@ -51,13 +54,16 @@ /** * A {@link ArraySearchPhaseResults} implementation that incrementally reduces aggregation results * as shard results are consumed. - * This implementation can be configured to batch up a certain amount of results and reduce - * them asynchronously in the provided {@link Executor} iff the buffer is exhausted. + * This implementation adds the memory that it used to save and reduce the results of shard aggregations + * in the {@link CircuitBreaker#REQUEST} circuit breaker. Before any partial or final reduce, the memory + * needed to reduce the aggregations is estimated and a {@link CircuitBreakingException} is thrown if it + * exceeds the maximum memory allowed in this breaker. */ -public class QueryPhaseResultConsumer extends ArraySearchPhaseResults { +public class QueryPhaseResultConsumer extends ArraySearchPhaseResults implements Releasable { private static final Logger logger = LogManager.getLogger(QueryPhaseResultConsumer.class); private final Executor executor; + private final CircuitBreaker circuitBreaker; private final SearchPhaseController controller; private final SearchProgressListener progressListener; private final ReduceContextBuilder aggReduceContextBuilder; @@ -71,15 +77,13 @@ public class QueryPhaseResultConsumer extends ArraySearchPhaseResults onPartialMergeFailure; - private volatile long aggsMaxBufferSize; - private volatile long aggsCurrentBufferSize; - /** * Creates a {@link QueryPhaseResultConsumer} that incrementally reduces aggregation results * as shard results are consumed. 
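* For example, mirroring the wiring in TermsReduceBenchmark above (variable names illustrative):
* <pre>{@code
* QueryPhaseResultConsumer consumer = new QueryPhaseResultConsumer(request, executor, breaker,
*     controller, SearchProgressListener.NOOP, namedWriteableRegistry, numShards, exc -> {});
* consumer.consumeResult(shardResult, () -> {}); // accounts the shard's serialized aggs in the breaker
* SearchPhaseController.ReducedQueryPhase phase = consumer.reduce(); // estimate, reduce, then correct
* consumer.close(); // releases whatever the consumer still holds in the breaker
* }</pre>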
*/ public QueryPhaseResultConsumer(SearchRequest request, Executor executor, + CircuitBreaker circuitBreaker, SearchPhaseController controller, SearchProgressListener progressListener, NamedWriteableRegistry namedWriteableRegistry, @@ -87,6 +91,7 @@ public QueryPhaseResultConsumer(SearchRequest request, Consumer onPartialMergeFailure) { super(expectedResultSize); this.executor = executor; + this.circuitBreaker = circuitBreaker; this.controller = controller; this.progressListener = progressListener; this.aggReduceContextBuilder = controller.getReduceContext(request); @@ -94,11 +99,17 @@ public QueryPhaseResultConsumer(SearchRequest request, this.topNSize = getTopDocsSize(request); this.performFinalReduce = request.isFinalReduce(); this.onPartialMergeFailure = onPartialMergeFailure; + SearchSourceBuilder source = request.source(); this.hasTopDocs = source == null || source.size() != 0; this.hasAggs = source != null && source.aggregations() != null; - int bufferSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize; - this.pendingMerges = new PendingMerges(bufferSize, request.resolveTrackTotalHitsUpTo()); + int batchReduceSize = (hasAggs || hasTopDocs) ? Math.min(request.getBatchedReduceSize(), expectedResultSize) : expectedResultSize; + this.pendingMerges = new PendingMerges(batchReduceSize, request.resolveTrackTotalHitsUpTo()); + } + + @Override + public void close() { + Releasables.close(pendingMerges); } @Override @@ -117,28 +128,35 @@ public SearchPhaseController.ReducedQueryPhase reduce() throws Exception { throw pendingMerges.getFailure(); } - logger.trace("aggs final reduction [{}] max [{}]", aggsCurrentBufferSize, aggsMaxBufferSize); // ensure consistent ordering pendingMerges.sortBuffer(); final TopDocsStats topDocsStats = pendingMerges.consumeTopDocsStats(); final List topDocsList = pendingMerges.consumeTopDocs(); final List aggsList = pendingMerges.consumeAggs(); + long breakerSize = pendingMerges.circuitBreakerBytes; + if (hasAggs) { + // Add an estimate of the final reduce size + breakerSize = pendingMerges.addEstimateAndMaybeBreak(pendingMerges.estimateRamBytesUsedForReduce(breakerSize)); + } SearchPhaseController.ReducedQueryPhase reducePhase = controller.reducedQueryPhase(results.asList(), aggsList, topDocsList, topDocsStats, pendingMerges.numReducePhases, false, aggReduceContextBuilder, performFinalReduce); + if (hasAggs) { + // Update the circuit breaker to replace the estimation with the serialized size of the newly reduced result + long finalSize = reducePhase.aggregations.getSerializedSize() - breakerSize; + pendingMerges.addWithoutBreaking(finalSize); + logger.trace("aggs final reduction [{}] max [{}]", + pendingMerges.aggsCurrentBufferSize, pendingMerges.maxAggsCurrentBufferSize); + } progressListener.notifyFinalReduce(SearchProgressListener.buildSearchShards(results.asList()), reducePhase.totalHits, reducePhase.aggregations, reducePhase.numReducePhases); return reducePhase; } - private MergeResult partialReduce(MergeTask task, + private MergeResult partialReduce(QuerySearchResult[] toConsume, + List emptyResults, TopDocsStats topDocsStats, MergeResult lastMerge, int numReducePhases) { - final QuerySearchResult[] toConsume = task.consumeBuffer(); - if (toConsume == null) { - // the task is cancelled - return null; - } // ensure consistent ordering Arrays.sort(toConsume, Comparator.comparingInt(QuerySearchResult::getShardIndex)); @@ -164,27 +182,20 @@ private MergeResult partialReduce(MergeTask task, newTopDocs 
= null; } - final DelayableWriteable.Serialized newAggs; + final InternalAggregations newAggs; if (hasAggs) { List aggsList = new ArrayList<>(); if (lastMerge != null) { - aggsList.add(lastMerge.reducedAggs.expand()); + aggsList.add(lastMerge.reducedAggs); } for (QuerySearchResult result : toConsume) { aggsList.add(result.consumeAggs().expand()); } - InternalAggregations result = InternalAggregations.topLevelReduce(aggsList, - aggReduceContextBuilder.forPartialReduction()); - newAggs = DelayableWriteable.referencing(result).asSerialized(InternalAggregations::readFrom, namedWriteableRegistry); - long previousBufferSize = aggsCurrentBufferSize; - aggsCurrentBufferSize = newAggs.ramBytesUsed(); - aggsMaxBufferSize = Math.max(aggsCurrentBufferSize, aggsMaxBufferSize); - logger.trace("aggs partial reduction [{}->{}] max [{}]", - previousBufferSize, aggsCurrentBufferSize, aggsMaxBufferSize); + newAggs = InternalAggregations.topLevelReduce(aggsList, aggReduceContextBuilder.forPartialReduction()); } else { newAggs = null; } - List processedShards = new ArrayList<>(task.emptyResults); + List processedShards = new ArrayList<>(emptyResults); if (lastMerge != null) { processedShards.addAll(lastMerge.processedShards); } @@ -193,49 +204,109 @@ private MergeResult partialReduce(MergeTask task, processedShards.add(new SearchShard(target.getClusterAlias(), target.getShardId())); } progressListener.notifyPartialReduce(processedShards, topDocsStats.getTotalHits(), newAggs, numReducePhases); - return new MergeResult(processedShards, newTopDocs, newAggs); + // we leave the results un-serialized because serializing is slow but we compute the serialized + // size as an estimate of the memory used by the newly reduced aggregations. + long serializedSize = hasAggs ? newAggs.getSerializedSize() : 0; + return new MergeResult(processedShards, newTopDocs, newAggs, hasAggs ? 
serializedSize : 0); } public int getNumReducePhases() { return pendingMerges.numReducePhases; } - private class PendingMerges { - private final int bufferSize; - - private int index; - private final QuerySearchResult[] buffer; + private class PendingMerges implements Releasable { + private final int batchReduceSize; + private final List buffer = new ArrayList<>(); private final List emptyResults = new ArrayList<>(); + // the memory that is accounted in the circuit breaker for this consumer + private volatile long circuitBreakerBytes; + // the memory that is currently used in the buffer + private volatile long aggsCurrentBufferSize; + private volatile long maxAggsCurrentBufferSize = 0; - private final TopDocsStats topDocsStats; - private MergeResult mergeResult; private final ArrayDeque queue = new ArrayDeque<>(); private final AtomicReference runningTask = new AtomicReference<>(); private final AtomicReference failure = new AtomicReference<>(); - private boolean hasPartialReduce; - private int numReducePhases; + private final TopDocsStats topDocsStats; + private volatile MergeResult mergeResult; + private volatile boolean hasPartialReduce; + private volatile int numReducePhases; - PendingMerges(int bufferSize, int trackTotalHitsUpTo) { - this.bufferSize = bufferSize; + PendingMerges(int batchReduceSize, int trackTotalHitsUpTo) { + this.batchReduceSize = batchReduceSize; this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo); - this.buffer = new QuerySearchResult[bufferSize]; } - public boolean hasFailure() { + @Override + public synchronized void close() { + assert hasPendingMerges() == false : "cannot close with partial reduce in-flight"; + if (hasFailure()) { + assert circuitBreakerBytes == 0; + return; + } + assert circuitBreakerBytes >= 0; + circuitBreaker.addWithoutBreaking(-circuitBreakerBytes); + circuitBreakerBytes = 0; + } + + synchronized Exception getFailure() { + return failure.get(); + } + + boolean hasFailure() { return failure.get() != null; } - public synchronized boolean hasPendingMerges() { + boolean hasPendingMerges() { return queue.isEmpty() == false || runningTask.get() != null; } - public synchronized void sortBuffer() { - if (index > 0) { - Arrays.sort(buffer, 0, index, Comparator.comparingInt(QuerySearchResult::getShardIndex)); + void sortBuffer() { + if (buffer.size() > 0) { + Collections.sort(buffer, Comparator.comparingInt(QuerySearchResult::getShardIndex)); } } + synchronized long addWithoutBreaking(long size) { + circuitBreaker.addWithoutBreaking(size); + circuitBreakerBytes += size; + maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes); + return circuitBreakerBytes; + } + + synchronized long addEstimateAndMaybeBreak(long estimatedSize) { + circuitBreaker.addEstimateBytesAndMaybeBreak(estimatedSize, ""); + circuitBreakerBytes += estimatedSize; + maxAggsCurrentBufferSize = Math.max(maxAggsCurrentBufferSize, circuitBreakerBytes); + return circuitBreakerBytes; + } + + /** + * Returns the size of the serialized aggregation that is contained in the + * provided {@link QuerySearchResult}. + */ + long ramBytesUsedQueryResult(QuerySearchResult result) { + if (hasAggs == false) { + return 0; + } + return result.aggregations() + .asSerialized(InternalAggregations::readFrom, namedWriteableRegistry) + .ramBytesUsed(); + } + + /** + * Returns an estimation of the size that a reduce of the provided size + * would take on memory. 
+ * This size is estimated as roughly 1.5 times the size of the serialized + * aggregations that need to be reduced. This estimation can be completely + * off for some aggregations but it is corrected with the real size after + * the reduce completes. + */ + long estimateRamBytesUsedForReduce(long size) { + return Math.round(1.5d * size - size); + } + public void consume(QuerySearchResult result, Runnable next) { boolean executeNextImmediately = true; synchronized (this) { @@ -247,20 +318,24 @@ public void consume(QuerySearchResult result, Runnable next) { } } else { // add one if a partial merge is pending - int size = index + (hasPartialReduce ? 1 : 0); - if (size >= bufferSize) { + int size = buffer.size() + (hasPartialReduce ? 1 : 0); + if (size >= batchReduceSize) { hasPartialReduce = true; executeNextImmediately = false; - QuerySearchResult[] clone = new QuerySearchResult[index]; - System.arraycopy(buffer, 0, clone, 0, index); - MergeTask task = new MergeTask(clone, new ArrayList<>(emptyResults), next); - Arrays.fill(buffer, null); + QuerySearchResult[] clone = buffer.stream().toArray(QuerySearchResult[]::new); + MergeTask task = new MergeTask(clone, aggsCurrentBufferSize, new ArrayList<>(emptyResults), next); + aggsCurrentBufferSize = 0; + buffer.clear(); emptyResults.clear(); - index = 0; queue.add(task); tryExecuteNext(); } - buffer[index++] = result; + if (hasAggs) { + long aggsSize = ramBytesUsedQueryResult(result); + addWithoutBreaking(aggsSize); + aggsCurrentBufferSize += aggsSize; + } + buffer.add(result); } } if (executeNextImmediately) { @@ -268,56 +343,85 @@ public void consume(QuerySearchResult result, Runnable next) { } } - private void onMergeFailure(Exception exc) { - synchronized (this) { - if (failure.get() != null) { - return; - } - failure.compareAndSet(null, exc); - MergeTask task = runningTask.get(); - runningTask.compareAndSet(task, null); - onPartialMergeFailure.accept(exc); - List toCancel = new ArrayList<>(); - if (task != null) { - toCancel.add(task); - } - toCancel.addAll(queue); - queue.clear(); - mergeResult = null; - toCancel.stream().forEach(MergeTask::cancel); + private synchronized void onMergeFailure(Exception exc) { + if (hasFailure()) { + assert circuitBreakerBytes == 0; + return; + } + assert circuitBreakerBytes >= 0; + if (circuitBreakerBytes > 0) { + // make sure that we reset the circuit breaker + circuitBreaker.addWithoutBreaking(-circuitBreakerBytes); + circuitBreakerBytes = 0; + } + failure.compareAndSet(null, exc); + MergeTask task = runningTask.get(); + runningTask.compareAndSet(task, null); + onPartialMergeFailure.accept(exc); + List toCancels = new ArrayList<>(); + if (task != null) { + toCancels.add(task); + } + queue.stream().forEach(toCancels::add); + queue.clear(); + mergeResult = null; + for (MergeTask toCancel : toCancels) { + toCancel.cancel(); } } - private void onAfterMerge(MergeTask task, MergeResult newResult) { + private void onAfterMerge(MergeTask task, MergeResult newResult, long estimatedSize) { synchronized (this) { + if (hasFailure()) { + return; + } runningTask.compareAndSet(task, null); mergeResult = newResult; + if (hasAggs) { + // Update the circuit breaker to remove the size of the source aggregations + // and replace the estimation with the serialized size of the newly reduced result. 
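+ // For example, if the buffered sources plus the reduce estimate accounted for 3MB but the newly
+ // reduced aggregations serialize to 1MB, newSize is -2MB and the breaker shrinks by that amount.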
+ long newSize = mergeResult.estimatedSize - estimatedSize; + addWithoutBreaking(newSize); + logger.trace("aggs partial reduction [{}->{}] max [{}]", + estimatedSize, mergeResult.estimatedSize, maxAggsCurrentBufferSize); + } + task.consumeListener(); } - task.consumeListener(); } private void tryExecuteNext() { final MergeTask task; synchronized (this) { if (queue.isEmpty() - || failure.get() != null + || hasFailure() || runningTask.get() != null) { return; } task = queue.poll(); runningTask.compareAndSet(null, task); } + executor.execute(new AbstractRunnable() { @Override protected void doRun() { + final MergeResult thisMergeResult = mergeResult; + long estimatedTotalSize = (thisMergeResult != null ? thisMergeResult.estimatedSize : 0) + task.aggsBufferSize; final MergeResult newMerge; try { - newMerge = partialReduce(task, topDocsStats, mergeResult, ++numReducePhases); + final QuerySearchResult[] toConsume = task.consumeBuffer(); + if (toConsume == null) { + return; + } + long estimatedMergeSize = estimateRamBytesUsedForReduce(estimatedTotalSize); + addEstimateAndMaybeBreak(estimatedMergeSize); + estimatedTotalSize += estimatedMergeSize; + ++ numReducePhases; + newMerge = partialReduce(toConsume, task.emptyResults, topDocsStats, thisMergeResult, numReducePhases); } catch (Exception t) { onMergeFailure(t); return; } - onAfterMerge(task, newMerge); + onAfterMerge(task, newMerge, estimatedTotalSize); tryExecuteNext(); } @@ -328,15 +432,14 @@ public void onFailure(Exception exc) { }); } - public TopDocsStats consumeTopDocsStats() { - for (int i = 0; i < index; i++) { - QuerySearchResult result = buffer[i]; + public synchronized TopDocsStats consumeTopDocsStats() { + for (QuerySearchResult result : buffer) { topDocsStats.add(result.topDocs(), result.searchTimedOut(), result.terminatedEarly()); } return topDocsStats; } - public List consumeTopDocs() { + public synchronized List consumeTopDocs() { if (hasTopDocs == false) { return Collections.emptyList(); } @@ -344,8 +447,7 @@ public List consumeTopDocs() { if (mergeResult != null) { topDocsList.add(mergeResult.reducedTopDocs); } - for (int i = 0; i < index; i++) { - QuerySearchResult result = buffer[i]; + for (QuerySearchResult result : buffer) { TopDocsAndMaxScore topDocs = result.consumeTopDocs(); setShardIndex(topDocs.topDocs, result.getShardIndex()); topDocsList.add(topDocs.topDocs); @@ -353,46 +455,45 @@ public List consumeTopDocs() { return topDocsList; } - public List consumeAggs() { + public synchronized List consumeAggs() { if (hasAggs == false) { return Collections.emptyList(); } List aggsList = new ArrayList<>(); if (mergeResult != null) { - aggsList.add(mergeResult.reducedAggs.expand()); + aggsList.add(mergeResult.reducedAggs); } - for (int i = 0; i < index; i++) { - QuerySearchResult result = buffer[i]; + for (QuerySearchResult result : buffer) { aggsList.add(result.consumeAggs().expand()); } return aggsList; } - - public Exception getFailure() { - return failure.get(); - } } private static class MergeResult { private final List processedShards; private final TopDocs reducedTopDocs; - private final DelayableWriteable.Serialized reducedAggs; + private final InternalAggregations reducedAggs; + private final long estimatedSize; private MergeResult(List processedShards, TopDocs reducedTopDocs, - DelayableWriteable.Serialized reducedAggs) { + InternalAggregations reducedAggs, long estimatedSize) { this.processedShards = processedShards; this.reducedTopDocs = reducedTopDocs; this.reducedAggs = reducedAggs; + this.estimatedSize = 
estimatedSize; } } private static class MergeTask { private final List emptyResults; private QuerySearchResult[] buffer; + private long aggsBufferSize; private Runnable next; - private MergeTask(QuerySearchResult[] buffer, List emptyResults, Runnable next) { + private MergeTask(QuerySearchResult[] buffer, long aggsBufferSize, List emptyResults, Runnable next) { this.buffer = buffer; + this.aggsBufferSize = aggsBufferSize; this.emptyResults = emptyResults; this.next = next; } @@ -403,7 +504,7 @@ public synchronized QuerySearchResult[] consumeBuffer() { return toRet; } - public synchronized void consumeListener() { + public void consumeListener() { if (next != null) { next.run(); next = null; diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java index 0762d70dc5cbf..b53e635d9866a 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchDfsQueryThenFetchAsyncAction.java @@ -35,29 +35,29 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.function.BiFunction; -import java.util.function.Consumer; final class SearchDfsQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { private final SearchPhaseController searchPhaseController; - private final Consumer onPartialMergeFailure; + + private final QueryPhaseResultConsumer queryPhaseResultConsumer; SearchDfsQueryThenFetchAsyncAction(final Logger logger, final SearchTransportService searchTransportService, final BiFunction nodeIdToConnection, final Map aliasFilter, final Map concreteIndexBoosts, final Map> indexRoutings, final SearchPhaseController searchPhaseController, final Executor executor, + final QueryPhaseResultConsumer queryPhaseResultConsumer, final SearchRequest request, final ActionListener listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, - final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters, - Consumer onPartialMergeFailure) { + final ClusterState clusterState, final SearchTask task, SearchResponse.Clusters clusters) { super("dfs", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, executor, request, listener, shardsIts, timeProvider, clusterState, task, new ArraySearchPhaseResults<>(shardsIts.size()), request.getMaxConcurrentShardRequests(), clusters); + this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.searchPhaseController = searchPhaseController; - this.onPartialMergeFailure = onPartialMergeFailure; SearchProgressListener progressListener = task.getProgressListener(); SearchSourceBuilder sourceBuilder = request.source(); progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), @@ -72,11 +72,12 @@ protected void executePhaseOnShard(final SearchShardIterator shardIt, final Sear } @Override - protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { + protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { final List dfsSearchResults = results.getAtomicArray().asList(); final AggregatedDfs aggregatedDfs = searchPhaseController.aggregateDfs(dfsSearchResults); - return new DfsQueryPhase(dfsSearchResults, aggregatedDfs, searchPhaseController, (queryResults) -> - new 
FetchSearchPhase(queryResults, searchPhaseController, aggregatedDfs, context), context, onPartialMergeFailure); + return new DfsQueryPhase(dfsSearchResults, aggregatedDfs, queryPhaseResultConsumer, + (queryResults) -> new FetchSearchPhase(queryResults, searchPhaseController, aggregatedDfs, context), + context); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java index 75ce64dc264eb..e56100dc5287f 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseContext.java @@ -21,6 +21,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; @@ -123,4 +124,9 @@ default void sendReleaseSearchContext(ShardSearchContextId contextId, * a response is returned to the user indicating that all shards have failed. */ void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase); + + /** + * Registers a {@link Releasable} that will be closed when the search request finishes or fails. + */ + void addReleasable(Releasable releasable); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java index a612e09a549f9..21dc1589c6579 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchPhaseController.java @@ -34,6 +34,7 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.search.TotalHits.Relation; import org.apache.lucene.search.grouping.CollapseTopFieldDocs; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.collect.HppcMaps; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; @@ -563,14 +564,16 @@ InternalAggregation.ReduceContextBuilder getReduceContext(SearchRequest request) } /** - * Returns a new {@link QueryPhaseResultConsumer} instance. This might return an instance that reduces search responses incrementally. + * Returns a new {@link QueryPhaseResultConsumer} instance that reduces search responses incrementally. 
*/ QueryPhaseResultConsumer newSearchPhaseResults(Executor executor, + CircuitBreaker circuitBreaker, SearchProgressListener listener, SearchRequest request, int numShards, Consumer onPartialMergeFailure) { - return new QueryPhaseResultConsumer(request, executor, this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure); + return new QueryPhaseResultConsumer(request, executor, circuitBreaker, + this, listener, namedWriteableRegistry, numShards, onPartialMergeFailure); } static final class TopDocsStats { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java index bbb9a5ad02388..f6670eb5e2f5c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchProgressListener.java @@ -25,7 +25,6 @@ import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.search.SearchResponse.Clusters; import org.elasticsearch.cluster.routing.GroupShardsIterator; -import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -78,11 +77,10 @@ protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc * * @param shards The list of shards that are part of this reduce. * @param totalHits The total number of hits in this reduce. - * @param aggs The partial result for aggregations stored in serialized form. + * @param aggs The partial result for aggregations. * @param reducePhase The version number for this reduce. */ - protected void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) {} + protected void onPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {} /** * Executed once when the final reduce is created. 
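To make the listener change above concrete: progress listeners now receive the partially reduced aggregations directly, with no serialized intermediate to expand. A minimal sketch of a listener written against the new signature follows; the class name and body are illustrative, and only the SearchProgressListener API shown in this patch is assumed.

    import org.apache.lucene.search.TotalHits;
    import org.elasticsearch.action.search.SearchProgressListener;
    import org.elasticsearch.action.search.SearchShard;
    import org.elasticsearch.search.aggregations.InternalAggregations;

    import java.util.List;

    // Sketch (not part of the patch): consumes partial reduces directly. Before
    // this change the aggs arrived in serialized form and had to be expanded
    // before use; now they are already plain InternalAggregations.
    class PartialReduceSketchListener extends SearchProgressListener {
        @Override
        protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits,
                                       InternalAggregations aggs, int reducePhase) {
            // aggs can be inspected or re-reduced here without any deserialization step
        }
    }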
@@ -137,8 +135,7 @@ final void notifyQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exc } } - final void notifyPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + final void notifyPartialReduce(List shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) { try { onPartialReduce(shards, totalHits, aggs, reducePhase); } catch (Exception e) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index f841c6e55f44b..79f5e5ca9571e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -26,7 +26,6 @@ import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; -import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; @@ -37,7 +36,6 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.function.BiFunction; -import java.util.function.Consumer; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; @@ -56,22 +54,26 @@ class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction aliasFilter, final Map concreteIndexBoosts, final Map> indexRoutings, final SearchPhaseController searchPhaseController, final Executor executor, - final SearchRequest request, final ActionListener listener, + final QueryPhaseResultConsumer resultConsumer, final SearchRequest request, + final ActionListener listener, final GroupShardsIterator shardsIts, final TransportSearchAction.SearchTimeProvider timeProvider, - ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters, - Consumer onPartialMergeFailure) { + ClusterState clusterState, SearchTask task, SearchResponse.Clusters clusters) { super("query", logger, searchTransportService, nodeIdToConnection, aliasFilter, concreteIndexBoosts, indexRoutings, executor, request, listener, shardsIts, timeProvider, clusterState, task, - searchPhaseController.newSearchPhaseResults(executor, task.getProgressListener(), - request, shardsIts.size(), onPartialMergeFailure), request.getMaxConcurrentShardRequests(), clusters); + resultConsumer, request.getMaxConcurrentShardRequests(), clusters); this.topDocsSize = getTopDocsSize(request); this.trackTotalHitsUpTo = request.resolveTrackTotalHitsUpTo(); this.searchPhaseController = searchPhaseController; this.progressListener = task.getProgressListener(); - final SearchSourceBuilder sourceBuilder = request.source(); + + // register the release of the query consumer to free up the circuit breaker memory + // at the end of the search + addReleasable(resultConsumer); + + boolean hasFetchPhase = request.source() == null ? 
true : request.source().size() > 0; progressListener.notifyListShards(SearchProgressListener.buildSearchShards(this.shardsIts), - SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, sourceBuilder == null || sourceBuilder.size() != 0); + SearchProgressListener.buildSearchShards(toSkipShardsIts), clusters, hasFetchPhase); } protected void executePhaseOnShard(final SearchShardIterator shardIt, @@ -108,8 +110,8 @@ && getRequest().scroll() == null } @Override - protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { - return new FetchSearchPhase(results, searchPhaseController, null, context); + protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { + return new FetchSearchPhase(results, searchPhaseController, null, this); } private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) { diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index aab586fa47e65..2c9a5f9e37e53 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -43,6 +43,7 @@ import org.elasticsearch.cluster.routing.ShardIterator; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -55,6 +56,7 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; @@ -115,10 +117,12 @@ public class TransportSearchAction extends HandledTransportAction) SearchRequest::new); this.client = client; this.threadPool = threadPool; + this.circuitBreaker = circuitBreakerService.getBreaker(CircuitBreaker.REQUEST); this.searchPhaseController = searchPhaseController; this.searchTransportService = searchTransportService; this.remoteClusterService = searchTransportService.getRemoteClusterService(); @@ -796,17 +801,19 @@ public void run() { }; }, clusters); } else { + final QueryPhaseResultConsumer queryResultConsumer = searchPhaseController.newSearchPhaseResults(executor, + circuitBreaker, task.getProgressListener(), searchRequest, shardIterators.size(), exc -> cancelTask(task, exc)); AbstractSearchAsyncAction searchAsyncAction; switch (searchRequest.searchType()) { case DFS_QUERY_THEN_FETCH: searchAsyncAction = new SearchDfsQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, - aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, - shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); + aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, + executor, queryResultConsumer, searchRequest, listener, shardIterators, timeProvider, clusterState, task, clusters); break; case QUERY_THEN_FETCH: searchAsyncAction = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, connectionLookup, - aliasFilter, 
concreteIndexBoosts, indexRoutings, searchPhaseController, executor, searchRequest, listener, - shardIterators, timeProvider, clusterState, task, clusters, exc -> cancelTask(task, exc)); + aliasFilter, concreteIndexBoosts, indexRoutings, searchPhaseController, executor, queryResultConsumer, + searchRequest, listener, shardIterators, timeProvider, clusterState, task, clusters); break; default: throw new IllegalStateException("Unknown search type: [" + searchRequest.searchType() + "]"); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java index b16f5cdd3b1f8..1522a789cd7c1 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/InternalAggregations.java @@ -272,4 +272,47 @@ public static InternalAggregations reduce(List aggregation public static InternalAggregations reduce(List aggregationsList, ReduceContext context) { return reduce(aggregationsList, context, InternalAggregations::from); } + + /** + * Returns the number of bytes required to serialize these aggregations in binary form. + */ + public long getSerializedSize() { + try (CountingStreamOutput out = new CountingStreamOutput()) { + out.setVersion(Version.CURRENT); + writeTo(out); + return out.size; + } catch (IOException exc) { + // should never happen + throw new RuntimeException(exc); + } + } + + private static class CountingStreamOutput extends StreamOutput { + long size = 0; + + @Override + public void writeByte(byte b) throws IOException { + ++ size; + } + + @Override + public void writeBytes(byte[] b, int offset, int length) throws IOException { + size += length; + } + + @Override + public void flush() throws IOException {} + + @Override + public void close() throws IOException {} + + @Override + public void reset() throws IOException { + size = 0; + } + + public long length() { + return size; + } + } } diff --git a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java index 916f9111517a5..9f1199f774a63 100644 --- a/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/AbstractSearchAsyncActionTests.java @@ -96,7 +96,7 @@ private AbstractSearchAsyncAction createAction(SearchRequest results, request.getMaxConcurrentShardRequests(), SearchResponse.Clusters.EMPTY) { @Override - protected SearchPhase getNextPhase(final SearchPhaseResults results, final SearchPhaseContext context) { + protected SearchPhase getNextPhase(final SearchPhaseResults results, SearchPhaseContext context) { return null; } diff --git a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java index 140d28c47fd9a..d71b14f3d12f3 100644 --- a/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/DfsQueryPhaseTests.java @@ -25,8 +25,11 @@ import org.apache.lucene.search.TotalHits; import org.apache.lucene.store.MockDirectoryWrapper; import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import 
org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchPhaseResult; @@ -86,15 +89,19 @@ public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest } } }; + SearchPhaseController searchPhaseController = searchPhaseController(); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); mockSearchPhaseContext.searchTransport = searchTransportService; - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, + results.length(), exc -> {}); + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext, exc -> {}); + }, mockSearchPhaseContext); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -141,15 +148,19 @@ public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest } } }; + SearchPhaseController searchPhaseController = searchPhaseController(); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); mockSearchPhaseContext.searchTransport = searchTransportService; - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, + results.length(), exc -> {}); + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext, exc -> {}); + }, mockSearchPhaseContext); assertEquals("dfs_query", phase.getName()); phase.run(); mockSearchPhaseContext.assertNoFailure(); @@ -198,15 +209,19 @@ public void sendExecuteQuery(Transport.Connection connection, QuerySearchRequest } } }; + SearchPhaseController searchPhaseController = searchPhaseController(); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(2); mockSearchPhaseContext.searchTransport = searchTransportService; - DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, searchPhaseController(), + QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.searchRequest, + results.length(), exc -> {}); + DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer, (response) -> new SearchPhase("test") { @Override public void run() throws IOException { responseRef.set(response.results); } - }, mockSearchPhaseContext, exc -> {}); + }, mockSearchPhaseContext); assertEquals("dfs_query", phase.getName()); expectThrows(UncheckedIOException.class, phase::run); 
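The tests in this file all wire the phase the same way, mirroring the production change: DfsQueryPhase no longer receives a SearchPhaseController plus a partial-merge failure callback, it is handed a prebuilt QueryPhaseResultConsumer. A condensed sketch of that wiring, with the same names used in the tests (a minimal sketch, not a complete test):

    // A no-op breaker opts unit tests out of memory accounting; production
    // code passes the real request circuit breaker here instead.
    QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(
        EsExecutors.newDirectExecutorService(),          // partial reduces run inline in tests
        new NoopCircuitBreaker(CircuitBreaker.REQUEST),
        SearchProgressListener.NOOP,
        mockSearchPhaseContext.searchRequest,
        results.length(),
        exc -> {});                                      // onPartialMergeFailure, ignored here
    DfsQueryPhase phase = new DfsQueryPhase(results.asList(), null, consumer,
        (response) -> new SearchPhase("test") {
            @Override
            public void run() {
                // capture response.results for assertions
            }
        }, mockSearchPhaseContext);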
assertTrue(mockSearchPhaseContext.releasedSearchContexts.isEmpty()); // phase execution will clean up on the contexts diff --git a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java index 9dc544091ae85..32d8e0d724686 100644 --- a/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/FetchSearchPhaseTests.java @@ -24,6 +24,8 @@ import org.apache.lucene.store.MockDirectoryWrapper; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.index.shard.ShardId; @@ -43,8 +45,6 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicInteger; -import static org.elasticsearch.action.search.SearchProgressListener.NOOP; - public class FetchSearchPhaseTests extends ESTestCase { public void testShortcutQueryAndFetchOptimization() { @@ -52,7 +52,8 @@ public void testShortcutQueryAndFetchOptimization() { writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 1, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 1, exc -> {}); boolean hasHits = randomBoolean(); final int numHits; if (hasHits) { @@ -96,7 +97,8 @@ public void testFetchTwoDocument() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx1, new SearchShardTarget("node1", new ShardId("test", "na", 0), @@ -157,7 +159,8 @@ public void testFailFetchOneDoc() { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx, @@ -220,7 +223,8 @@ public void testFetchDocsConcurrently() throws InterruptedException { SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); MockSearchPhaseContext 
mockSearchPhaseContext = new MockSearchPhaseContext(numHits); - QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), NOOP, + QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, mockSearchPhaseContext.getRequest(), numHits, exc -> {}); for (int i = 0; i < numHits; i++) { QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", i), @@ -279,7 +283,8 @@ public void testExceptionFailsPhase() { writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = randomIntBetween(2, 10); QuerySearchResult queryResult = new QuerySearchResult(new ShardSearchContextId("", 123), new SearchShardTarget("node1", new ShardId("test", "na", 0), @@ -337,7 +342,8 @@ public void testCleanupIrrelevantContexts() { // contexts that are not fetched s SearchPhaseController controller = new SearchPhaseController( writableRegistry(), s -> InternalAggregationTestCase.emptyReduceContextBuilder()); QueryPhaseResultConsumer results = controller.newSearchPhaseResults(EsExecutors.newDirectExecutorService(), - NOOP, mockSearchPhaseContext.getRequest(), 2, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + mockSearchPhaseContext.getRequest(), 2, exc -> {}); int resultSetSize = 1; final ShardSearchContextId ctx1 = new ShardSearchContextId(UUIDs.base64UUID(), 123); QuerySearchResult queryResult = new QuerySearchResult(ctx1, diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index cf1a96dde422c..96e9fe7a61c1f 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -23,6 +23,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchShardTarget; @@ -131,6 +132,11 @@ public void executeNextPhase(SearchPhase currentPhase, SearchPhase nextPhase) { } } + @Override + public void addReleasable(Releasable releasable) { + // Noop + } + @Override public void execute(Runnable command) { command.run(); diff --git a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java index 2e9f7f7af6f41..f44a0cf292d5e 100644 --- a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java @@ -23,7 +23,8 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.elasticsearch.action.OriginalIndices; -import org.elasticsearch.common.io.stream.DelayableWriteable; +import 
org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -93,8 +94,9 @@ public void testProgressListenerExceptionsAreCaught() throws Exception { SearchRequest searchRequest = new SearchRequest("index"); searchRequest.setBatchedReduceSize(2); AtomicReference onPartialMergeFailure = new AtomicReference<>(); - QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, searchPhaseController, - searchProgressListener, writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), searchPhaseController, searchProgressListener, + writableRegistry(), 10, e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { curr.addSuppressed(prev); return curr; })); @@ -140,7 +142,7 @@ protected void onQueryResult(int shardIndex) { @Override protected void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + InternalAggregations aggs, int reducePhase) { onPartialReduce.incrementAndGet(); throw new UnsupportedOperationException(); } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java index 99181976fce65..676da3da9e63b 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchAsyncActionTests.java @@ -460,8 +460,7 @@ protected void executePhaseOnShard(SearchShardIterator shardIt, } @Override - protected SearchPhase getNextPhase(SearchPhaseResults results, - SearchPhaseContext context) { + protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { @Override public void run() { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 7971a7d831106..2898e203a13a1 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -33,10 +33,10 @@ import org.elasticsearch.action.OriginalIndices; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; -import org.elasticsearch.common.io.stream.DelayableWriteable; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.settings.Settings; @@ -45,7 +45,6 @@ import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor; 
-import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchHit; @@ -77,7 +76,6 @@ import org.junit.After; import org.junit.Before; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -95,7 +93,6 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; -import static org.elasticsearch.action.search.SearchProgressListener.NOOP; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -111,9 +108,9 @@ public class SearchPhaseControllerTests extends ESTestCase { @Override protected NamedWriteableRegistry writableRegistry() { - List entries = - new ArrayList<>(new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables()); - entries.add(new NamedWriteableRegistry.Entry(InternalAggregation.class, "throwing", InternalThrowing::new)); + List entries = new ArrayList<>( + new SearchModule(Settings.EMPTY, false, emptyList()).getNamedWriteables() + ); return new NamedWriteableRegistry(entries); } @@ -419,7 +416,8 @@ private void consumerTestCase(int numEmptyResponses) throws Exception { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.setBatchedReduceSize(bufferSize); ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, 3+numEmptyResponses, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, 3+numEmptyResponses, exc -> {}); if (numEmptyResponses == 0) { assertEquals(0, reductions.size()); } @@ -506,7 +504,8 @@ public void testConsumerConcurrently() throws Exception { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo"))); request.setBatchedReduceSize(bufferSize); ArraySearchPhaseResults consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; CountDownLatch latch = new CountDownLatch(expectedNumResults); @@ -556,7 +555,8 @@ public void testConsumerOnlyAggs() throws Exception { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); CountDownLatch latch = new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { @@ -597,7 +597,8 @@ public void testConsumerOnlyHits() throws Exception { } request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); CountDownLatch latch 
= new CountDownLatch(expectedNumResults); for (int i = 0; i < expectedNumResults; i++) { @@ -640,7 +641,8 @@ public void testReduceTopNWithFromOffset() throws Exception { request.source(new SearchSourceBuilder().size(5).from(5)); request.setBatchedReduceSize(randomIntBetween(2, 4)); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, 4, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP + , request, 4, exc -> {}); int score = 100; CountDownLatch latch = new CountDownLatch(4); for (int i = 0; i < 4; i++) { @@ -678,7 +680,8 @@ public void testConsumerSortByField() throws Exception { int size = randomIntBetween(1, 10); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); SortField[] sortFields = {new SortField("field", SortField.Type.INT, true)}; DocValueFormat[] docValueFormats = {DocValueFormat.RAW}; @@ -716,7 +719,8 @@ public void testConsumerFieldCollapsing() throws Exception { int size = randomIntBetween(5, 10); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); SortField[] sortFields = {new SortField("field", SortField.Type.STRING)}; BytesRef a = new BytesRef("a"); BytesRef b = new BytesRef("b"); @@ -757,7 +761,8 @@ public void testConsumerSuggestions() throws Exception { SearchRequest request = randomSearchRequest(); request.setBatchedReduceSize(bufferSize); QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), SearchProgressListener.NOOP, + request, expectedNumResults, exc -> {}); int maxScoreTerm = -1; int maxScorePhrase = -1; int maxScoreCompletion = -1; @@ -871,7 +876,7 @@ public void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Except @Override public void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggs, int reducePhase) { + InternalAggregations aggs, int reducePhase) { assertEquals(numReduceListener.incrementAndGet(), reducePhase); } @@ -883,7 +888,7 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna } }; QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - progressListener, request, expectedNumResults, exc -> {}); + new NoopCircuitBreaker(CircuitBreaker.REQUEST), progressListener, request, expectedNumResults, exc -> {}); AtomicInteger max = new AtomicInteger(); Thread[] threads = new Thread[expectedNumResults]; CountDownLatch latch = new CountDownLatch(expectedNumResults); @@ -932,7 +937,19 @@ public void onFinalReduce(List shards, TotalHits totalHits, Interna } } - public void testPartialMergeFailure() throws InterruptedException { + public void testPartialReduce() throws Exception { + for (int i = 0; i < 10; i++) { + testReduceCase(false); + } + } + + public void testPartialReduceWithFailure() throws Exception { + for (int i = 0; i < 10; i++) { + testReduceCase(true); + } + } + + 
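Both entry points above loop over testReduceCase, which exercises the accounting contract of the consumer: memory is reserved in the breaker before each reduce as an overestimate, corrected to the real size afterwards, and fully returned on close. A sketch of that lifecycle, assuming only the breaker and InternalAggregations methods that appear elsewhere in this patch; the helper name, the rollback, and the overestimate shown are simplifications, not the exact implementation.

    import org.elasticsearch.common.breaker.CircuitBreaker;
    import org.elasticsearch.search.aggregations.InternalAggregation;
    import org.elasticsearch.search.aggregations.InternalAggregations;

    import java.util.List;

    class ReduceAccountingSketch {
        // estimate -> maybe break -> reduce -> correct -> (on close) release all
        static InternalAggregations reduceWithBreaker(CircuitBreaker breaker,
                                                      List<InternalAggregations> buffered,
                                                      InternalAggregation.ReduceContext ctx,
                                                      long serializedSize) {
            long estimate = serializedSize + serializedSize / 2;  // rough upper bound, corrected below
            breaker.addEstimateBytesAndMaybeBreak(estimate, "<reduce_aggs>"); // throws CircuitBreakingException over budget
            try {
                InternalAggregations reduced = InternalAggregations.reduce(buffered, ctx);
                breaker.addWithoutBreaking(reduced.getSerializedSize() - estimate); // swap estimate for real size
                return reduced;
            } catch (Exception e) {
                breaker.addWithoutBreaking(-estimate);            // give the reservation back on failure
                throw e;
            }
        }
    }

On close, the consumer returns whatever it still has accounted, which is why the test ends by asserting that circuitBreaker.allocated is back to 0.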
private void testReduceCase(boolean shouldFail) throws Exception { int expectedNumResults = randomIntBetween(20, 200); int bufferSize = randomIntBetween(2, expectedNumResults - 1); SearchRequest request = new SearchRequest(); @@ -940,11 +957,16 @@ public void testPartialMergeFailure() throws InterruptedException { request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0)); request.setBatchedReduceSize(bufferSize); AtomicBoolean hasConsumedFailure = new AtomicBoolean(); + AssertingCircuitBreaker circuitBreaker = new AssertingCircuitBreaker(CircuitBreaker.REQUEST); + boolean shouldFailPartial = shouldFail && randomBoolean(); + if (shouldFailPartial) { + circuitBreaker.shouldBreak.set(true); + } QueryPhaseResultConsumer consumer = searchPhaseController.newSearchPhaseResults(fixedExecutor, - NOOP, request, expectedNumResults, exc -> hasConsumedFailure.set(true)); + circuitBreaker, SearchProgressListener.NOOP, + request, expectedNumResults, exc -> hasConsumedFailure.set(true)); CountDownLatch latch = new CountDownLatch(expectedNumResults); Thread[] threads = new Thread[expectedNumResults]; - int failedIndex = randomIntBetween(0, expectedNumResults-1); for (int i = 0; i < expectedNumResults; i++) { final int index = i; threads[index] = new Thread(() -> { @@ -955,7 +977,7 @@ public void testPartialMergeFailure() throws InterruptedException { new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Lucene.EMPTY_SCORE_DOCS), Float.NaN), new DocValueFormat[0]); InternalAggregations aggs = InternalAggregations.from( - Collections.singletonList(new InternalThrowing("test", (failedIndex == index), Collections.emptyMap()))); + Collections.singletonList(new InternalMax("test", 0d, DocValueFormat.RAW, Collections.emptyMap()))); result.aggregations(aggs); result.setShardIndex(index); result.size(1); @@ -967,65 +989,44 @@ public void testPartialMergeFailure() throws InterruptedException { threads[i].join(); } latch.await(); - IllegalStateException exc = expectThrows(IllegalStateException.class, () -> consumer.reduce()); - if (exc.getMessage().contains("partial reduce")) { - assertTrue(hasConsumedFailure.get()); + if (shouldFail) { + if (shouldFailPartial == false) { + circuitBreaker.shouldBreak.set(true); + } + CircuitBreakingException exc = expectThrows(CircuitBreakingException.class, () -> consumer.reduce()); + assertEquals(shouldFailPartial, hasConsumedFailure.get()); + assertThat(exc.getMessage(), containsString("<reduce_aggs>")); + circuitBreaker.shouldBreak.set(false); } else { - assertThat(exc.getMessage(), containsString("final reduce")); + SearchPhaseController.ReducedQueryPhase phase = consumer.reduce(); } + consumer.close(); + assertThat(circuitBreaker.allocated, equalTo(0L)); } - private static class InternalThrowing extends InternalAggregation { - private final boolean shouldThrow; - - protected InternalThrowing(String name, boolean shouldThrow, Map metadata) { - super(name, metadata); - this.shouldThrow = shouldThrow; - } + private static class AssertingCircuitBreaker extends NoopCircuitBreaker { + private final AtomicBoolean shouldBreak = new AtomicBoolean(false); - protected InternalThrowing(StreamInput in) throws IOException { - super(in); - this.shouldThrow = in.readBoolean(); - } + private volatile long allocated; - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeBoolean(shouldThrow); + AssertingCircuitBreaker(String name) { + super(name); } @Override - public InternalAggregation reduce(List aggregations, ReduceContext
reduceContext) { - if (aggregations.stream() - .map(agg -> (InternalThrowing) agg) - .anyMatch(agg -> agg.shouldThrow)) { - if (reduceContext.isFinalReduce()) { - throw new IllegalStateException("final reduce"); - } else { - throw new IllegalStateException("partial reduce"); - } + public double addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + assert bytes >= 0; + if (shouldBreak.get()) { + throw new CircuitBreakingException(label, getDurability()); } - return new InternalThrowing(name, false, metadata); - } - - @Override - protected boolean mustReduceOnSingleInternalAgg() { - return true; + allocated += bytes; + return allocated; } @Override - public Object getProperty(List path) { - return null; - } - - @Override - public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { - throw new IllegalStateException("not implemented"); - } - - @Override - public String getWriteableName() { - return "throwing"; + public long addWithoutBreaking(long bytes) { + allocated += bytes; + return allocated; } } - } diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index cbdc5b56c85b8..9c1d4bf3448df 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -29,6 +29,8 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.GroupShardsIterator; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -51,6 +53,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -144,15 +147,19 @@ public void sendExecuteQuery(Transport.Connection connection, ShardSearchRequest searchRequest.source().collapse(new CollapseBuilder("collapse_field")); } searchRequest.allowPartialSearchResults(false); + Executor executor = EsExecutors.newDirectExecutorService(); SearchPhaseController controller = new SearchPhaseController( writableRegistry(), r -> InternalAggregationTestCase.emptyReduceContextBuilder()); SearchTask task = new SearchTask(0, "n/a", "n/a", () -> "test", null, Collections.emptyMap()); + QueryPhaseResultConsumer resultConsumer = new QueryPhaseResultConsumer(searchRequest, executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), controller, task.getProgressListener(), writableRegistry(), + shardsIter.size(), exc -> {}); SearchQueryThenFetchAsyncAction action = new SearchQueryThenFetchAsyncAction(logger, searchTransportService, (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", new AliasFilter(null, Strings.EMPTY_ARRAY)), - Collections.emptyMap(), Collections.emptyMap(), controller, EsExecutors.newDirectExecutorService(), searchRequest, - null, shardsIter, timeProvider, null, task, - SearchResponse.Clusters.EMPTY, exc -> {}) { + Collections.emptyMap(), Collections.emptyMap(), controller, 
executor, + resultConsumer, searchRequest, null, shardsIter, timeProvider, null, + task, SearchResponse.Clusters.EMPTY) { @Override protected SearchPhase getNextPhase(SearchPhaseResults results, SearchPhaseContext context) { return new SearchPhase("test") { diff --git a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionSingleNodeTests.java b/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionSingleNodeTests.java deleted file mode 100644 index 7e27aaac59ecb..0000000000000 --- a/server/src/test/java/org/elasticsearch/action/search/TransportSearchActionSingleNodeTests.java +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.action.search; - -import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.action.support.IndicesOptions; -import org.elasticsearch.action.support.WriteRequest; -import org.elasticsearch.common.Strings; -import org.elasticsearch.index.query.RangeQueryBuilder; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.search.SearchHit; -import org.elasticsearch.search.aggregations.Aggregations; -import org.elasticsearch.search.aggregations.bucket.terms.LongTerms; -import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregationBuilder; -import org.elasticsearch.search.aggregations.support.ValueType; -import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.test.ESSingleNodeTestCase; - -public class TransportSearchActionSingleNodeTests extends ESSingleNodeTestCase { - - public void testLocalClusterAlias() { - long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); - IndexRequest indexRequest = new IndexRequest("test"); - indexRequest.id("1"); - indexRequest.source("field", "value"); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, - "local", nowInMillis, randomBoolean()); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - SearchHit[] hits = searchResponse.getHits().getHits(); - assertEquals(1, hits.length); - SearchHit hit = hits[0]; - assertEquals("local", hit.getClusterAlias()); - assertEquals("test", hit.getIndex()); - assertEquals("1", hit.getId()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), Strings.EMPTY_ARRAY, - "", nowInMillis, randomBoolean()); - SearchResponse searchResponse = 
client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - SearchHit[] hits = searchResponse.getHits().getHits(); - assertEquals(1, hits.length); - SearchHit hit = hits[0]; - assertEquals("", hit.getClusterAlias()); - assertEquals("test", hit.getIndex()); - assertEquals("1", hit.getId()); - } - } - - public void testAbsoluteStartMillis() { - { - IndexRequest indexRequest = new IndexRequest("test-1970.01.01"); - indexRequest.id("1"); - indexRequest.source("date", "1970-01-01"); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - { - IndexRequest indexRequest = new IndexRequest("test-1982.01.01"); - indexRequest.id("1"); - indexRequest.source("date", "1982-01-01"); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - { - SearchRequest searchRequest = new SearchRequest(); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - } - { - SearchRequest searchRequest = new SearchRequest(""); - searchRequest.indicesOptions(IndicesOptions.fromOptions(true, true, true, true)); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(0, searchResponse.getTotalShards()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), - Strings.EMPTY_ARRAY, "", 0, randomBoolean()); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), - Strings.EMPTY_ARRAY, "", 0, randomBoolean()); - searchRequest.indices(""); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(new SearchRequest(), - Strings.EMPTY_ARRAY, "", 0, randomBoolean()); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - RangeQueryBuilder rangeQuery = new RangeQueryBuilder("date"); - rangeQuery.gte("1970-01-01"); - rangeQuery.lt("1982-01-01"); - sourceBuilder.query(rangeQuery); - searchRequest.source(sourceBuilder); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(1, searchResponse.getHits().getTotalHits().value); - assertEquals("test-1970.01.01", searchResponse.getHits().getHits()[0].getIndex()); - } - } - - public void testFinalReduce() { - long nowInMillis = randomLongBetween(0, Long.MAX_VALUE); - { - IndexRequest indexRequest = new IndexRequest("test"); - indexRequest.id("1"); - indexRequest.source("price", 10); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - { - IndexRequest indexRequest = new IndexRequest("test"); - indexRequest.id("2"); - indexRequest.source("price", 100); - IndexResponse indexResponse = client().index(indexRequest).actionGet(); - assertEquals(RestStatus.CREATED, indexResponse.status()); - } - 
client().admin().indices().prepareRefresh("test").get(); - - SearchRequest originalRequest = new SearchRequest(); - SearchSourceBuilder source = new SearchSourceBuilder(); - source.size(0); - originalRequest.source(source); - TermsAggregationBuilder terms = new TermsAggregationBuilder("terms").userValueTypeHint(ValueType.NUMERIC); - terms.field("price"); - terms.size(1); - source.aggregation(terms); - - { - SearchRequest searchRequest = randomBoolean() ? originalRequest : SearchRequest.subSearchRequest(originalRequest, - Strings.EMPTY_ARRAY, "remote", nowInMillis, true); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - Aggregations aggregations = searchResponse.getAggregations(); - LongTerms longTerms = aggregations.get("terms"); - assertEquals(1, longTerms.getBuckets().size()); - } - { - SearchRequest searchRequest = SearchRequest.subSearchRequest(originalRequest, - Strings.EMPTY_ARRAY, "remote", nowInMillis, false); - SearchResponse searchResponse = client().search(searchRequest).actionGet(); - assertEquals(2, searchResponse.getHits().getTotalHits().value); - Aggregations aggregations = searchResponse.getAggregations(); - LongTerms longTerms = aggregations.get("terms"); - assertEquals(2, longTerms.getBuckets().size()); - } - } -} diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java index 96072fca36b91..4fcf0255f6203 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/InternalAggregationsTests.java @@ -18,6 +18,8 @@ */ package org.elasticsearch.search.aggregations; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -44,6 +46,7 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonList; +import static org.hamcrest.Matchers.equalTo; public class InternalAggregationsTests extends ESTestCase { @@ -126,19 +129,32 @@ private static PipelineAggregator.PipelineTree randomPipelineTree() { public void testSerialization() throws Exception { InternalAggregations aggregations = createTestInstance(); - writeToAndReadFrom(aggregations, 0); + writeToAndReadFrom(aggregations, Version.CURRENT, 0); } - private void writeToAndReadFrom(InternalAggregations aggregations, int iteration) throws IOException { - try (BytesStreamOutput out = new BytesStreamOutput()) { - aggregations.writeTo(out); - try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(out.bytes().toBytesRef().bytes), registry)) { - InternalAggregations deserialized = InternalAggregations.readFrom(in); - assertEquals(aggregations.aggregations, deserialized.aggregations); - if (iteration < 2) { - writeToAndReadFrom(deserialized, iteration + 1); - } + public void testSerializedSize() throws Exception { + InternalAggregations aggregations = createTestInstance(); + assertThat(aggregations.getSerializedSize(), + equalTo((long) serialize(aggregations, Version.CURRENT).length)); + } + + private void writeToAndReadFrom(InternalAggregations aggregations, Version version, int iteration) throws IOException { + BytesRef 
serializedAggs = serialize(aggregations, version); + try (StreamInput in = new NamedWriteableAwareStreamInput(StreamInput.wrap(serializedAggs.bytes), registry)) { + in.setVersion(version); + InternalAggregations deserialized = InternalAggregations.readFrom(in); + assertEquals(aggregations.aggregations, deserialized.aggregations); + if (iteration < 2) { + writeToAndReadFrom(deserialized, version, iteration + 1); } } } + + private BytesRef serialize(InternalAggregations aggs, Version version) throws IOException { + try (BytesStreamOutput out = new BytesStreamOutput()) { + out.setVersion(version); + aggs.writeTo(out); + return out.bytes().toBytesRef(); + } + } } diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index affa437759cb3..5541c679b8f4c 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -1616,7 +1616,7 @@ clusterService, indicesService, threadPool, shardStateAction, mappingUpdatedActi SearchPhaseController searchPhaseController = new SearchPhaseController( writableRegistry(), searchService::aggReduceContextBuilder); actions.put(SearchAction.INSTANCE, - new TransportSearchAction(client, threadPool, transportService, searchService, + new TransportSearchAction(client, threadPool, new NoneCircuitBreakerService(), transportService, searchService, searchTransportService, searchPhaseController, clusterService, actionFilters, indexNameExpressionResolver, namedWriteableRegistry)); actions.put(RestoreSnapshotAction.INSTANCE, diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index 6ae6655ef6b6b..7bd4553776169 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.search.SearchTask; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.client.Client; -import org.elasticsearch.common.io.stream.DelayableWriteable; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.InternalAggregation; @@ -391,7 +390,7 @@ protected void onListShards(List shards, List skipped, @Override public void onPartialReduce(List shards, TotalHits totalHits, - DelayableWriteable.Serialized aggregations, int reducePhase) { + InternalAggregations aggregations, int reducePhase) { // best effort to cancel expired tasks checkCancellation(); // The way that the MutableSearchResponse will build the aggs. @@ -401,16 +400,15 @@ public void onPartialReduce(List shards, TotalHits totalHits, reducedAggs = () -> null; } else { /* - * Keep a reference to the serialized form of the partially - * reduced aggs and reduce it on the fly when someone asks + * Keep a reference to the partially reduced aggs and reduce it on the fly when someone asks * for it. It's important that we wait until someone needs * the result so we don't perform the final reduce only to * throw it away. 
diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java
index 01f57a07ee817..d06b47d9cf5d6 100644
--- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java
+++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java
@@ -16,8 +16,6 @@
 import org.elasticsearch.action.search.ShardSearchFailure;
 import org.elasticsearch.common.breaker.CircuitBreaker;
 import org.elasticsearch.common.breaker.CircuitBreakingException;
-import org.elasticsearch.common.io.stream.DelayableWriteable;
-import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.shard.ShardId;
@@ -25,7 +23,6 @@
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.aggregations.BucketOrder;
-import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.aggregations.bucket.terms.StringTerms;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
@@ -155,56 +152,14 @@ public void onFailure(Exception e) {
         latch.await();
     }
 
-    public void testGetResponseFailureDuringReduction() throws InterruptedException {
-        AsyncSearchTask task = createAsyncSearchTask();
-        task.getSearchProgressActionListener().onListShards(Collections.emptyList(), Collections.emptyList(),
-            SearchResponse.Clusters.EMPTY, false);
-        InternalAggregations aggs = InternalAggregations.from(Collections.singletonList(new StringTerms("name", BucketOrder.key(true),
-            BucketOrder.key(true), 1, 1, Collections.emptyMap(), DocValueFormat.RAW, 1, false, 1, Collections.emptyList(), 0)));
-        //providing an empty named writeable registry will make the expansion fail, hence the delayed reduction will fail too
-        //causing an exception when executing getResponse as part of the completion listener callback
-        DelayableWriteable.Serialized<InternalAggregations> serializedAggs = DelayableWriteable.referencing(aggs)
-            .asSerialized(InternalAggregations::readFrom, new NamedWriteableRegistry(Collections.emptyList()));
-        task.getSearchProgressActionListener().onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO),
-            serializedAggs, 1);
-        AtomicReference<AsyncSearchResponse> response = new AtomicReference<>();
-        CountDownLatch latch = new CountDownLatch(1);
-        task.addCompletionListener(new ActionListener<AsyncSearchResponse>() {
-            @Override
-            public void onResponse(AsyncSearchResponse asyncSearchResponse) {
-                assertTrue(response.compareAndSet(null, asyncSearchResponse));
-                latch.countDown();
-            }
-
-            @Override
-            public void onFailure(Exception e) {
-                throw new AssertionError("onFailure should not be called");
-            }
-        }, TimeValue.timeValueMillis(10L));
-        assertTrue(latch.await(1, TimeUnit.SECONDS));
-        assertNotNull(response.get().getSearchResponse());
-        assertEquals(0, response.get().getSearchResponse().getTotalShards());
-        assertEquals(0, response.get().getSearchResponse().getSuccessfulShards());
-        assertEquals(0, response.get().getSearchResponse().getFailedShards());
-        assertThat(response.get().getFailure(), instanceOf(ElasticsearchException.class));
-        assertEquals("Async search: error while reducing partial results", response.get().getFailure().getMessage());
-        assertThat(response.get().getFailure().getCause(), instanceOf(IllegalArgumentException.class));
-        assertEquals("Unknown NamedWriteable category [" + InternalAggregation.class.getName() + "]",
-            response.get().getFailure().getCause().getMessage());
-    }
-
     public void testWithFailureAndGetResponseFailureDuringReduction() throws InterruptedException {
         AsyncSearchTask task = createAsyncSearchTask();
         task.getSearchProgressActionListener().onListShards(Collections.emptyList(), Collections.emptyList(),
             SearchResponse.Clusters.EMPTY, false);
         InternalAggregations aggs = InternalAggregations.from(Collections.singletonList(new StringTerms("name", BucketOrder.key(true),
             BucketOrder.key(true), 1, 1, Collections.emptyMap(), DocValueFormat.RAW, 1, false, 1, Collections.emptyList(), 0)));
-        //providing an empty named writeable registry will make the expansion fail, hence the delayed reduction will fail too
-        //causing an exception when executing getResponse as part of the completion listener callback
-        DelayableWriteable.Serialized<InternalAggregations> serializedAggs = DelayableWriteable.referencing(aggs)
-            .asSerialized(InternalAggregations::readFrom, new NamedWriteableRegistry(Collections.emptyList()));
         task.getSearchProgressActionListener().onPartialReduce(Collections.emptyList(), new TotalHits(0, TotalHits.Relation.EQUAL_TO),
-            serializedAggs, 1);
+            aggs, 1);
         task.getSearchProgressActionListener().onFailure(new CircuitBreakingException("boom", CircuitBreaker.Durability.TRANSIENT));
         AtomicReference<AsyncSearchResponse> response = new AtomicReference<>();
         CountDownLatch latch = new CountDownLatch(1);
@@ -229,9 +184,6 @@ public void onFailure(Exception e) {
             Exception failure = asyncSearchResponse.getFailure();
             assertThat(failure, instanceOf(ElasticsearchException.class));
             assertEquals("Async search: error while reducing partial results", failure.getMessage());
-            assertThat(failure.getCause(), instanceOf(IllegalArgumentException.class));
-            assertEquals("Unknown NamedWriteable category [" + InternalAggregation.class.getName() +
-                "]", failure.getCause().getMessage());
             assertEquals(1, failure.getSuppressed().length);
             assertThat(failure.getSuppressed()[0], instanceOf(ElasticsearchException.class));
             assertEquals("error while executing search", failure.getSuppressed()[0].getMessage());
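
With the aggregations no longer held in serialized form there is no deserialization step left to fail, so the test that forced an expansion failure goes away; the surviving test injects a CircuitBreakingException directly to simulate the request breaker tripping during reduction. For context, a rough sketch of the reserve-and-adjust pattern a breaker-aware reduce can follow (helper names hypothetical; this is not the PR's actual consumer code):

    // Reserve an estimate up front; this throws CircuitBreakingException
    // if it would push the request breaker past its limit.
    long estimate = estimateReduceSize();                             // hypothetical helper
    breaker.addEstimateBytesAndMaybeBreak(estimate, "<reduce_aggs>");
    try {
        InternalAggregations reduced = doReduce();                    // hypothetical helper
        // Swap the estimate for the real footprint of the reduced result.
        breaker.addWithoutBreaking(reduced.getSerializedSize() - estimate);
    } catch (Exception e) {
        breaker.addWithoutBreaking(-estimate);                        // give the reservation back
        throw e;
    }
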
From 1f45e249a884546d967d609ee26dd4852b0c5d5b Mon Sep 17 00:00:00 2001
From: jimczi
Date: Thu, 24 Sep 2020 16:15:54 +0200
Subject: [PATCH 2/2] Fix TransportSearchIT#testCircuitBreakerReduceFail

Ensures that the test always runs with a memory circuit breaker.
Relates #62223
---
 .../elasticsearch/action/search/TransportSearchIT.java | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java
index 72af251447103..e267903b42e97 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/TransportSearchIT.java
@@ -113,6 +113,14 @@ public void process(FetchSubPhase.HitContext hitContext) {
         }
     }
 
+    @Override
+    protected Settings nodeSettings(int nodeOrdinal) {
+        return Settings.builder()
+            .put(super.nodeSettings(nodeOrdinal))
+            .put("indices.breaker.request.type", "memory")
+            .build();
+    }
+
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins() {
         return Collections.singletonList(TestPlugin.class);