Skip to content

Relax some search interfaces to allow arbitrary cancellable tasks (#122188) #126293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.cluster.routing.IndexRouting;
import org.elasticsearch.common.lucene.search.Queries;
Expand Down Expand Up @@ -77,6 +76,7 @@
import org.elasticsearch.search.slice.SliceBuilder;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.suggest.SuggestionSearchContext;
import org.elasticsearch.tasks.CancellableTask;

import java.io.IOException;
import java.io.UncheckedIOException;
Expand Down Expand Up @@ -132,7 +132,7 @@ final class DefaultSearchContext extends SearchContext {
private CollapseContext collapse;
// filter for sliced scroll
private SliceBuilder sliceBuilder;
private SearchShardTask task;
private CancellableTask task;
private QueryPhaseRankShardContext queryPhaseRankShardContext;

/**
Expand Down Expand Up @@ -466,7 +466,7 @@ public void preProcess() {
this.query = buildFilteredQuery(query);
if (lowLevelCancellation) {
searcher().addQueryCancellation(() -> {
final SearchShardTask task = getTask();
final CancellableTask task = getTask();
if (task != null) {
task.ensureNotCancelled();
}
Expand Down Expand Up @@ -941,12 +941,12 @@ public void setProfilers(Profilers profilers) {
}

@Override
public void setTask(SearchShardTask task) {
public void setTask(CancellableTask task) {
this.task = task;
}

@Override
public SearchShardTask getTask() {
public CancellableTask getTask() {
return task;
}

Expand Down
11 changes: 6 additions & 5 deletions server/src/main/java/org/elasticsearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
import org.elasticsearch.search.sort.SortBuilder;
import org.elasticsearch.search.suggest.Suggest;
import org.elasticsearch.search.suggest.completion.CompletionSuggestion;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.threadpool.Scheduler;
Expand Down Expand Up @@ -590,7 +591,7 @@ private void loadOrExecuteQueryPhase(final ShardSearchRequest request, final Sea
}
}

public void executeQueryPhase(ShardSearchRequest request, SearchShardTask task, ActionListener<SearchPhaseResult> listener) {
public void executeQueryPhase(ShardSearchRequest request, CancellableTask task, ActionListener<SearchPhaseResult> listener) {
ActionListener<SearchPhaseResult> finalListener = maybeWrapListenerForStackTrace(listener, request.getChannelVersion(), threadPool);
assert request.canReturnNullResponseIfMatchNoDocs() == false || request.numberOfShards() > 1
: "empty responses require more than one shard";
Expand Down Expand Up @@ -738,7 +739,7 @@ private static <T extends RefCounted> void runAsync(
* It is the responsibility of the caller to ensure that the ref count is correctly decremented
* when the object is no longer needed.
*/
private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, SearchShardTask task) throws Exception {
private SearchPhaseResult executeQueryPhase(ShardSearchRequest request, CancellableTask task) throws Exception {
final ReaderContext readerContext = createOrGetReaderContext(request);
try (
Releasable scope = tracer.withScope(task);
Expand Down Expand Up @@ -998,7 +999,7 @@ public void executeFetchPhase(
}, wrapFailureListener(listener, readerContext, markAsUsed));
}

public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, ActionListener<FetchSearchResult> listener) {
public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, ActionListener<FetchSearchResult> listener) {
final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
Expand Down Expand Up @@ -1038,7 +1039,7 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
}));
}

protected void checkCancelled(SearchShardTask task) {
protected void checkCancelled(CancellableTask task) {
// check cancellation as early as possible, as it avoids opening up a Lucene reader on FrozenEngine
try {
task.ensureNotCancelled();
Expand Down Expand Up @@ -1169,7 +1170,7 @@ public void openReaderContext(ShardId shardId, TimeValue keepAlive, ActionListen
protected SearchContext createContext(
ReaderContext readerContext,
ShardSearchRequest request,
SearchShardTask task,
CancellableTask task,
ResultsType resultsType,
boolean includeAggregations
) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.apache.lucene.search.MatchNoDocsQuery;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.index.query.QueryBuilder;
Expand All @@ -38,6 +37,7 @@
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.aggregations.support.ValuesSourceRegistry;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.xcontent.ParseField;

import java.io.IOException;
Expand Down Expand Up @@ -128,7 +128,7 @@ private static SignificantTermsAggregatorSupplier bytesSupplier() {
* <p>
* Some searches that will never match can still fall through and we endup running query that will produce no results.
* However even in that case we sometimes do expensive things like loading global ordinals. This method should prevent this.
* Note that if {@link org.elasticsearch.search.SearchService#executeQueryPhase(ShardSearchRequest, SearchShardTask, ActionListener)}
* Note that if {@link org.elasticsearch.search.SearchService#executeQueryPhase(ShardSearchRequest, CancellableTask, ActionListener)}
* always do a can match then we don't need this code here.
*/
static boolean matchNoDocs(AggregationContext context, Aggregator parent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
Expand Down Expand Up @@ -40,6 +39,7 @@
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.suggest.SuggestionSearchContext;
import org.elasticsearch.tasks.CancellableTask;

import java.util.List;

Expand Down Expand Up @@ -422,12 +422,12 @@ public SearchExecutionContext getSearchExecutionContext() {
}

@Override
public void setTask(SearchShardTask task) {
public void setTask(CancellableTask task) {
in.setTask(task);
}

@Override
public SearchShardTask getTask() {
public CancellableTask getTask() {
return in.getTask();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.core.Assertions;
import org.elasticsearch.core.Nullable;
Expand Down Expand Up @@ -48,6 +47,7 @@
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.suggest.SuggestionSearchContext;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.transport.LeakTracker;

import java.io.IOException;
Expand Down Expand Up @@ -90,7 +90,7 @@ public final List<Runnable> getCancellationChecks() {
if (lowLevelCancellation()) {
// This searching doesn't live beyond this phase, so we don't need to remove query cancellation
Runnable c = () -> {
final SearchShardTask task = getTask();
final CancellableTask task = getTask();
if (task != null) {
task.ensureNotCancelled();
}
Expand All @@ -100,9 +100,9 @@ public final List<Runnable> getCancellationChecks() {
return timeoutRunnable == null ? List.of() : List.of(timeoutRunnable);
}

public abstract void setTask(SearchShardTask task);
public abstract void setTask(CancellableTask task);

public abstract SearchShardTask getTask();
public abstract CancellableTask getTask();

public abstract boolean isCancelled();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
Expand Down Expand Up @@ -48,6 +47,7 @@
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.suggest.SuggestionSearchContext;
import org.elasticsearch.tasks.CancellableTask;

import java.util.List;

Expand Down Expand Up @@ -211,12 +211,12 @@ public long getRelativeTimeInMillis() {
/* ---- ALL METHODS ARE UNSUPPORTED BEYOND HERE ---- */

@Override
public void setTask(SearchShardTask task) {
public void setTask(CancellableTask task) {
throw new UnsupportedOperationException();
}

@Override
public SearchShardTask getTask() {
public CancellableTask getTask() {
throw new UnsupportedOperationException();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.TestSearchContext;
Expand Down Expand Up @@ -93,7 +94,7 @@ public ShardSearchRequest request() {
}

@Override
public SearchShardTask getTask() {
public CancellableTask getTask() {
return super.getTask();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

package org.elasticsearch.search;

import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.core.TimeValue;
Expand All @@ -25,6 +24,7 @@
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
import org.elasticsearch.search.rank.feature.RankFeatureShardPhase;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.threadpool.ThreadPool;

Expand All @@ -48,7 +48,7 @@ public static class TestPlugin extends Plugin {}

private Consumer<SearchContext> onCreateSearchContext = context -> {};

private Function<SearchShardTask, SearchShardTask> onCheckCancelled = Function.identity();
private Function<CancellableTask, CancellableTask> onCheckCancelled = Function.identity();

/** Throw an {@link AssertionError} if there are still in-flight contexts. */
public static void assertNoInFlightContext() {
Expand Down Expand Up @@ -138,7 +138,7 @@ public void setOnCreateSearchContext(Consumer<SearchContext> onCreateSearchConte
protected SearchContext createContext(
ReaderContext readerContext,
ShardSearchRequest request,
SearchShardTask task,
CancellableTask task,
ResultsType resultsType,
boolean includeAggregations
) throws IOException {
Expand All @@ -160,12 +160,12 @@ public SearchContext createSearchContext(ShardSearchRequest request, TimeValue t
return searchContext;
}

public void setOnCheckCancelled(Function<SearchShardTask, SearchShardTask> onCheckCancelled) {
public void setOnCheckCancelled(Function<CancellableTask, CancellableTask> onCheckCancelled) {
this.onCheckCancelled = onCheckCancelled;
}

@Override
protected void checkCancelled(SearchShardTask task) {
protected void checkCancelled(CancellableTask task) {
super.checkCancelled(onCheckCancelled.apply(task));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TotalHits;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexService;
Expand Down Expand Up @@ -49,6 +48,7 @@
import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.search.suggest.SuggestionSearchContext;
import org.elasticsearch.tasks.CancellableTask;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -67,7 +67,7 @@ public class TestSearchContext extends SearchContext {
ParsedQuery postFilter;
Query query;
Float minScore;
SearchShardTask task;
CancellableTask task;
SortAndFormats sort;
boolean trackScores = false;
int trackTotalHitsUpTo = SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO;
Expand Down Expand Up @@ -506,12 +506,12 @@ public SearchExecutionContext getSearchExecutionContext() {
}

@Override
public void setTask(SearchShardTask task) {
public void setTask(CancellableTask task) {
this.task = task;
}

@Override
public SearchShardTask getTask() {
public CancellableTask getTask() {
return task;
}

Expand Down