Skip to content

Check for early termination in Driver #118188

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jan 15, 2025
5 changes: 5 additions & 0 deletions docs/changelog/118188.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 118188
summary: Check for early termination in Driver
area: ES|QL
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,13 @@ SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplie
long nextStatus = startTime + statusNanos;
int iter = 0;
while (true) {
IsBlockedResult isBlocked = runSingleLoopIteration();
IsBlockedResult isBlocked = Operator.NOT_BLOCKED;
try {
isBlocked = runSingleLoopIteration();
} catch (DriverEarlyTerminationException unused) {
closeEarlyFinishedOperators();
assert isFinished() : "not finished after early termination";
}
iter++;
if (isBlocked.listener().isDone() == false) {
updateStatus(nowSupplier.getAsLong() - startTime, iter, DriverStatus.Status.ASYNC, isBlocked.reason());
Expand Down Expand Up @@ -242,39 +248,59 @@ public void abort(Exception reason, ActionListener<Void> listener) {
}

private IsBlockedResult runSingleLoopIteration() {
ensureNotCancelled();
driverContext.checkForEarlyTermination();
boolean movedPage = false;

if (activeOperators.isEmpty() == false && activeOperators.getLast().isFinished() == false) {
for (int i = 0; i < activeOperators.size() - 1; i++) {
Operator op = activeOperators.get(i);
Operator nextOp = activeOperators.get(i + 1);
for (int i = 0; i < activeOperators.size() - 1; i++) {
Operator op = activeOperators.get(i);
Operator nextOp = activeOperators.get(i + 1);

// skip blocked operator
if (op.isBlocked().listener().isDone() == false) {
continue;
}
// skip blocked operator
if (op.isBlocked().listener().isDone() == false) {
continue;
}

if (op.isFinished() == false && nextOp.needsInput()) {
Page page = op.getOutput();
if (page == null) {
// No result, just move to the next iteration
} else if (page.getPositionCount() == 0) {
// Empty result, release any memory it holds immediately and move to the next iteration
if (op.isFinished() == false && nextOp.needsInput()) {
driverContext.checkForEarlyTermination();
Page page = op.getOutput();
if (page == null) {
// No result, just move to the next iteration
} else if (page.getPositionCount() == 0) {
// Empty result, release any memory it holds immediately and move to the next iteration
page.releaseBlocks();
} else {
// Non-empty result from the previous operation, move it to the next operation
try {
driverContext.checkForEarlyTermination();
} catch (DriverEarlyTerminationException | TaskCancelledException e) {
page.releaseBlocks();
} else {
// Non-empty result from the previous operation, move it to the next operation
nextOp.addInput(page);
movedPage = true;
throw e;
}
nextOp.addInput(page);
movedPage = true;
}
}

if (op.isFinished()) {
nextOp.finish();
}
if (op.isFinished()) {
driverContext.checkForEarlyTermination();
nextOp.finish();
}
}

closeEarlyFinishedOperators();

if (movedPage == false) {
return oneOf(
activeOperators.stream()
.map(Operator::isBlocked)
.filter(laf -> laf.listener().isDone() == false)
.collect(Collectors.toList())
);
}
return Operator.NOT_BLOCKED;
}

private void closeEarlyFinishedOperators() {
for (int index = activeOperators.size() - 1; index >= 0; index--) {
if (activeOperators.get(index).isFinished()) {
/*
Expand All @@ -300,16 +326,6 @@ private IsBlockedResult runSingleLoopIteration() {
break;
}
}

if (movedPage == false) {
return oneOf(
activeOperators.stream()
.map(Operator::isBlocked)
.filter(laf -> laf.listener().isDone() == false)
.collect(Collectors.toList())
);
}
return Operator.NOT_BLOCKED;
}

public void cancel(String reason) {
Expand All @@ -318,13 +334,6 @@ public void cancel(String reason) {
}
}

/**
 * Fails fast when a cancellation reason has been recorded for this driver.
 *
 * @throws TaskCancelledException carrying the recorded reason, if one is present
 */
private void ensureNotCancelled() {
    final String cancellation = cancelReason.get();
    if (cancellation == null) {
        // Nothing recorded: the driver may keep running.
        return;
    }
    throw new TaskCancelledException(cancellation);
}

public static void start(
ThreadContext threadContext,
Executor executor,
Expand All @@ -335,19 +344,36 @@ public static void start(
driver.completionListener.addListener(listener);
if (driver.started.compareAndSet(false, true)) {
driver.updateStatus(0, 0, DriverStatus.Status.STARTING, "driver starting");
// Register a listener to an exchange sink to handle early completion scenarios:
// 1. When the query accumulates sufficient data (e.g., reaching the LIMIT).
// 2. When users abort the query but want to retain the current result.
// This allows the Driver to finish early without waiting for the scheduled task.
if (driver.activeOperators.isEmpty() == false) {
if (driver.activeOperators.getLast() instanceof ExchangeSinkOperator sinkOperator) {
sinkOperator.addCompletionListener(ActionListener.running(driver.scheduler::runPendingTasks));
}
}
initializeEarlyTerminationChecker(driver);
schedule(DEFAULT_TIME_BEFORE_YIELDING, maxIterations, threadContext, executor, driver, driver.completionListener);
}
}

/**
 * Installs the cancellation / early-termination checker on the driver's {@code DriverContext} and, when the
 * last active operator is an {@code ExchangeSinkOperator}, hooks its completion so the checker starts failing
 * once the sink is done.
 *
 * @param driver the driver whose context and sink are wired together
 */
private static void initializeEarlyTerminationChecker(Driver driver) {
    // Flipped by the sink completion listener below; read by the checker on every poll.
    final AtomicBoolean sinkCompleted = new AtomicBoolean();
    final Runnable terminationCheck = () -> {
        // Cancellation is checked first, so it wins when both conditions hold.
        final String cancellation = driver.cancelReason.get();
        if (cancellation != null) {
            throw new TaskCancelledException(cancellation);
        }
        if (sinkCompleted.get()) {
            throw new DriverEarlyTerminationException("Exchange sink is closed");
        }
    };
    driver.driverContext.initializeEarlyTerminationChecker(terminationCheck);

    if (driver.activeOperators.isEmpty()) {
        return;
    }
    if (driver.activeOperators.getLast() instanceof ExchangeSinkOperator exchangeSink) {
        // Register a listener to an exchange sink to handle early completion scenarios:
        // 1. When the query accumulates sufficient data (e.g., reaching the LIMIT).
        // 2. When users abort the query but want to retain the current result.
        // This allows the Driver to finish early without waiting for the scheduled task.
        exchangeSink.addCompletionListener(ActionListener.running(() -> {
            sinkCompleted.set(true);
            driver.scheduler.runPendingTasks();
        }));
    }
}

// Drains all active operators and closes them.
private void drainAndCloseOperators(@Nullable Exception e) {
Iterator<Operator> itr = activeOperators.iterator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public class DriverContext {

private final WarningsMode warningsMode;

private Runnable earlyTerminationChecker = () -> {};

public DriverContext(BigArrays bigArrays, BlockFactory blockFactory) {
this(bigArrays, blockFactory, WarningsMode.COLLECT);
}
Expand Down Expand Up @@ -175,6 +177,21 @@ public void removeAsyncAction() {
asyncActions.removeInstance();
}

/**
* Checks if the Driver associated with this DriverContext has been cancelled or early terminated.
* <p>
* This simply runs the checker installed through {@link #initializeEarlyTerminationChecker(Runnable)}.
* Until a checker is installed, the default one is a no-op and this method returns normally.
* The outcome is reported by the checker throwing, not by a return value: the checker installed by the Driver
* throws {@code TaskCancelledException} when a cancel reason is set and {@code DriverEarlyTerminationException}
* once its exchange sink has completed.
*/
public void checkForEarlyTermination() {
earlyTerminationChecker.run();
}

/**
* Initializes the early termination or cancellation checker for this DriverContext.
* This method should be called when associating this DriverContext with a driver.
*
* @param checker the callback run by every {@link #checkForEarlyTermination()} call; it replaces whatever
*                checker was set before (initially a no-op). It is expected to throw when execution should
*                stop and to return normally otherwise.
*/
public void initializeEarlyTerminationChecker(Runnable checker) {
this.earlyTerminationChecker = checker;
}

/**
* Evaluators should use this function to decide their warning behavior.
* @return an appropriate {@link WarningsMode}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import org.elasticsearch.ElasticsearchException;

/**
* An exception indicating that a compute should be terminated early because the downstream pipeline
* already has enough data or no longer requires more data.
*/
public final class DriverEarlyTerminationException extends ElasticsearchException {
public DriverEarlyTerminationException(String message) {
super(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.compute.data.BasicBlockTests;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BlockFactory;
import org.elasticsearch.compute.data.ElementType;
import org.elasticsearch.compute.data.Page;
Expand All @@ -40,6 +41,7 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.LongSupplier;

Expand Down Expand Up @@ -280,6 +282,49 @@ public Page getOutput() {
}
}

/**
* Verifies that an operator polling {@code DriverContext#checkForEarlyTermination()} stops processing rows
* as soon as the exchange sink is finished, and that the driver still completes normally.
*/
public void testEarlyTermination() {
DriverContext driverContext = driverContext();
ThreadPool threadPool = threadPool();
try {
// Each page holds at least 1000 rows while maxAllowedRows is at most 100,
// so the row limit is always reached while evaluating the first page.
int positions = between(1000, 5000);
List<Page> inPages = randomList(1, 100, () -> {
var block = driverContext.blockFactory().newConstantIntBlockWith(randomInt(), positions);
return new Page(block);
});
final var sourceOperator = new CannedSourceOperator(inPages.iterator());
final int maxAllowedRows = between(1, 100);
// Counts every row the evaluator actually visits; compared against maxAllowedRows at the end.
final AtomicInteger processedRows = new AtomicInteger(0);
var sinkHandler = new ExchangeSinkHandler(driverContext.blockFactory(), positions, System::currentTimeMillis);
var sinkOperator = new ExchangeSinkOperator(sinkHandler.createExchangeSink(), Function.identity());
final var delayOperator = new EvalOperator(driverContext.blockFactory(), new EvalOperator.ExpressionEvaluator() {
@Override
public Block eval(Page page) {
for (int i = 0; i < page.getPositionCount(); i++) {
// Polled before every row; expected to throw once the sink completion listener wired up by Driver.start has fired.
driverContext.checkForEarlyTermination();
if (processedRows.incrementAndGet() >= maxAllowedRows) {
// NOTE(review): presumably passing `true` tells the handler the consumer is done, which completes the sink — confirm against ExchangeSinkHandler.
sinkHandler.fetchPageAsync(true, ActionListener.noop());
}
}
return driverContext.blockFactory().newConstantBooleanBlockWith(true, page.getPositionCount());
}

@Override
public void close() {

}
});
Driver driver = new Driver(driverContext, sourceOperator, List.of(delayOperator), sinkOperator, () -> {});
ThreadContext threadContext = threadPool.getThreadContext();
PlainActionFuture<Void> future = new PlainActionFuture<>();

Driver.start(threadContext, threadPool.executor("esql"), driver, between(1, 1000), future);
// The driver must finish on its own (no exception surfaced through the future) within the timeout.
future.actionGet(30, TimeUnit.SECONDS);
// Exactly maxAllowedRows rows were counted: the row that reached the limit was the last one processed.
assertThat(processedRows.get(), equalTo(maxAllowedRows));
} finally {
terminate(threadPool);
}
}

public void testResumeOnEarlyFinish() throws Exception {
DriverContext driverContext = driverContext();
ThreadPool threadPool = threadPool();
Expand Down

This file was deleted.

Loading