From e6d74a0f17d9451771da7acdbc2460172b74a039 Mon Sep 17 00:00:00 2001 From: Armin Braun Date: Mon, 10 Feb 2025 19:29:07 +0100 Subject: [PATCH] Relax some search interfaces to allow arbitrary cancellable tasks (#122188) An easy change we can split out of #121885 to make that shorter. --- .../elasticsearch/search/DefaultSearchContext.java | 10 +++++----- .../java/org/elasticsearch/search/SearchService.java | 11 ++++++----- .../terms/SignificantTermsAggregatorFactory.java | 4 ++-- .../search/internal/FilteredSearchContext.java | 6 +++--- .../elasticsearch/search/internal/SearchContext.java | 8 ++++---- .../elasticsearch/search/rank/RankSearchContext.java | 6 +++--- .../org/elasticsearch/index/SearchSlowLogTests.java | 3 ++- .../org/elasticsearch/search/MockSearchService.java | 10 +++++----- .../org/elasticsearch/test/TestSearchContext.java | 8 ++++---- 9 files changed, 34 insertions(+), 32 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index b3e8dc3b90dfe..06f23e69a3e74 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -21,7 +21,6 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; import org.apache.lucene.util.NumericUtils; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.cluster.routing.IndexRouting; import org.elasticsearch.common.lucene.search.Queries; @@ -77,6 +76,7 @@ import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.io.IOException; import java.io.UncheckedIOException; @@ -132,7 +132,7 @@ final class DefaultSearchContext extends SearchContext { private CollapseContext collapse; // filter for sliced scroll private SliceBuilder sliceBuilder; - private SearchShardTask task; + private CancellableTask task; private QueryPhaseRankShardContext queryPhaseRankShardContext; /** @@ -466,7 +466,7 @@ public void preProcess() { this.query = buildFilteredQuery(query); if (lowLevelCancellation) { searcher().addQueryCancellation(() -> { - final SearchShardTask task = getTask(); + final CancellableTask task = getTask(); if (task != null) { task.ensureNotCancelled(); } @@ -941,12 +941,12 @@ public void setProfilers(Profilers profilers) { } @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { this.task = task; } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return task; } diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index f4f063aff470b..a8eac9a41dbc3 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -128,6 +128,7 @@ import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.suggest.Suggest; import org.elasticsearch.search.suggest.completion.CompletionSuggestion; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.Scheduler; @@ -590,7 +591,7 @@ private void loadOrExecuteQueryPhase(final ShardSearchRequest request, final Sea } } - public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task, ActionListener listener) { + public void executeQueryPhase(ShardSearchRequest request, CancellableTask task, ActionListener listener) { ActionListener finalListener = maybeWrapListenerForStackTrace(listener, request.getChannelVersion(), threadPool); assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1 : "empty responses require more than one shard"; @@ -738,7 +739,7 @@ private static void runAsync( * It is the responsibility of the caller to ensure that the ref count is correctly decremented * when the object is no longer needed. */ - private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchShardTask task) throws Exception { + private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, CancellableTask task) throws Exception { final ReaderContext readerContext = createOrGetReaderContext(request); try ( Releasable scope = tracer.withScope(task); @@ -998,7 +999,7 @@ public void executeFetchPhase( }, wrapFailureListener(listener, readerContext, markAsUsed)); } - public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, ActionListener listener) { + public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, ActionListener listener) { final ReaderContext readerContext = findReaderContext(request.contextId(), request); final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest()); final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest)); @@ -1038,7 +1039,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A })); } - protected void checkCancelled(SearchShardTask task) { + protected void checkCancelled(CancellableTask task) { // check cancellation as early as possible, as it avoids opening up a Lucene reader on FrozenEngine try { task.ensureNotCancelled(); @@ -1169,7 +1170,7 @@ public void openReaderContext(ShardId shardId, TimeValue keepAlive, ActionListen protected SearchContext createContext( ReaderContext readerContext, ShardSearchRequest request, - SearchShardTask task, + CancellableTask task, ResultsType resultsType, boolean includeAggregations ) throws IOException { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java index 1ff7529bf3188..3ee0520520931 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/SignificantTermsAggregatorFactory.java @@ -13,7 +13,6 @@ import org.apache.lucene.search.MatchNoDocsQuery; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.index.query.QueryBuilder; @@ -38,6 +37,7 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.xcontent.ParseField; import java.io.IOException; @@ -128,7 +128,7 @@ private static SignificantTermsAggregatorSupplier bytesSupplier() { *

* Some searches that will never match can still fall through and we endup running query that will produce no results. * However even in that case we sometimes do expensive things like loading global ordinals. This method should prevent this. - * Note that if {@link org.elasticsearch.search.SearchService#executeQueryPhase(ShardSearchRequest, SearchShardTask, ActionListener)} + * Note that if {@link org.elasticsearch.search.SearchService#executeQueryPhase(ShardSearchRequest, CancellableTask, ActionListener)} * always do a can match then we don't need this code here. */ static boolean matchNoDocs(AggregationContext context, Aggregator parent) { diff --git a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java index 8c4f912c5988c..5bad06d08f96b 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/FilteredSearchContext.java @@ -12,7 +12,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; @@ -40,6 +39,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.util.List; @@ -422,12 +422,12 @@ public SearchExecutionContext getSearchExecutionContext() { } @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { in.setTask(task); } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return in.getTask(); } diff --git a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java index 45ac9fa06f399..90fa2b5ffc5c4 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/internal/SearchContext.java @@ -11,7 +11,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.Assertions; import org.elasticsearch.core.Nullable; @@ -48,6 +47,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.transport.LeakTracker; import java.io.IOException; @@ -90,7 +90,7 @@ public final List getCancellationChecks() { if (lowLevelCancellation()) { // This searching doesn't live beyond this phase, so we don't need to remove query cancellation Runnable c = () -> { - final SearchShardTask task = getTask(); + final CancellableTask task = getTask(); if (task != null) { task.ensureNotCancelled(); } @@ -100,9 +100,9 @@ public final List getCancellationChecks() { return timeoutRunnable == null ? List.of() : List.of(timeoutRunnable); } - public abstract void setTask(SearchShardTask task); + public abstract void setTask(CancellableTask task); - public abstract SearchShardTask getTask(); + public abstract CancellableTask getTask(); public abstract boolean isCancelled(); diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java index ad70e7d39aff8..951a9b0cf3520 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankSearchContext.java @@ -12,7 +12,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.cache.bitset.BitsetFilterCache; @@ -48,6 +47,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.util.List; @@ -211,12 +211,12 @@ public long getRelativeTimeInMillis() { /* ---- ALL METHODS ARE UNSUPPORTED BEYOND HERE ---- */ @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { throw new UnsupportedOperationException(); } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { throw new UnsupportedOperationException(); } diff --git a/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java b/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java index 359118c7cb5a1..d4aec300c666b 100644 --- a/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java +++ b/server/src/test/java/org/elasticsearch/index/SearchSlowLogTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESSingleNodeTestCase; import org.elasticsearch.test.TestSearchContext; @@ -93,7 +94,7 @@ public ShardSearchRequest request() { } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return super.getTask(); } }; diff --git a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java index 778a6e3106f49..3600df134c206 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java +++ b/test/framework/src/main/java/org/elasticsearch/search/MockSearchService.java @@ -9,7 +9,6 @@ package org.elasticsearch.search; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.core.TimeValue; @@ -25,6 +24,7 @@ import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.rank.feature.RankFeatureShardPhase; +import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.telemetry.tracing.Tracer; import org.elasticsearch.threadpool.ThreadPool; @@ -48,7 +48,7 @@ public static class TestPlugin extends Plugin {} private Consumer onCreateSearchContext = context -> {}; - private Function onCheckCancelled = Function.identity(); + private Function onCheckCancelled = Function.identity(); /** Throw an {@link AssertionError} if there are still in-flight contexts. */ public static void assertNoInFlightContext() { @@ -138,7 +138,7 @@ public void setOnCreateSearchContext(Consumer onCreateSearchConte protected SearchContext createContext( ReaderContext readerContext, ShardSearchRequest request, - SearchShardTask task, + CancellableTask task, ResultsType resultsType, boolean includeAggregations ) throws IOException { @@ -160,12 +160,12 @@ public SearchContext createSearchContext(ShardSearchRequest request, TimeValue t return searchContext; } - public void setOnCheckCancelled(Function onCheckCancelled) { + public void setOnCheckCancelled(Function onCheckCancelled) { this.onCheckCancelled = onCheckCancelled; } @Override - protected void checkCancelled(SearchShardTask task) { + protected void checkCancelled(CancellableTask task) { super.checkCancelled(onCheckCancelled.apply(task)); } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java index 103cf1c15abc1..c46442485ff9e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/elasticsearch/test/TestSearchContext.java @@ -11,7 +11,6 @@ import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.Query; import org.apache.lucene.search.TotalHits; -import org.elasticsearch.action.search.SearchShardTask; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexService; @@ -49,6 +48,7 @@ import org.elasticsearch.search.rescore.RescoreContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestionSearchContext; +import org.elasticsearch.tasks.CancellableTask; import java.util.Collections; import java.util.HashMap; @@ -67,7 +67,7 @@ public class TestSearchContext extends SearchContext { ParsedQuery postFilter; Query query; Float minScore; - SearchShardTask task; + CancellableTask task; SortAndFormats sort; boolean trackScores = false; int trackTotalHitsUpTo = SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO; @@ -506,12 +506,12 @@ public SearchExecutionContext getSearchExecutionContext() { } @Override - public void setTask(SearchShardTask task) { + public void setTask(CancellableTask task) { this.task = task; } @Override - public SearchShardTask getTask() { + public CancellableTask getTask() { return task; }