Skip to content

Commit 97bb696

Browse files
committed
Check for early termination
1 parent b34e278 commit 97bb696

File tree

234 files changed

+1295
-282
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

234 files changed

+1295
-282
lines changed

x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Evaluator.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,10 @@
4242
* into a warning and turn into a null value.
4343
*/
4444
Class<? extends Exception>[] warnExceptions() default {};
45+
46+
/**
47+
* The estimated cost of evaluating one row of this evaluator. The Driver periodically checks for cancellation or early termination
48+
* after it has accumulated the cost over the threshold (2028). Increase this estimate for expensive evaluators.
49+
*/
50+
int estimateCost() default 1;
4551
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorImplementer.java

Lines changed: 90 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,20 @@ public class EvaluatorImplementer {
6161
private final TypeElement declarationType;
6262
private final ProcessFunction processFunction;
6363
private final ClassName implementation;
64+
private final int estimateCost;
6465
private final boolean processOutputsMultivalued;
6566

6667
public EvaluatorImplementer(
6768
Elements elements,
6869
javax.lang.model.util.Types types,
6970
ExecutableElement processFunction,
7071
String extraName,
72+
int estimateCost,
7173
List<TypeMirror> warnExceptions
7274
) {
7375
this.declarationType = (TypeElement) processFunction.getEnclosingElement();
7476
this.processFunction = new ProcessFunction(elements, types, processFunction, warnExceptions);
75-
77+
this.estimateCost = estimateCost;
7678
this.implementation = ClassName.get(
7779
elements.getPackageOf(declarationType).toString(),
7880
declarationType.getSimpleName() + extraName + "Evaluator"
@@ -175,7 +177,7 @@ private MethodSpec realEval(boolean blockStyle) {
175177
boolean vectorize = false;
176178
if (blockStyle == false && processFunction.warnExceptions.isEmpty() && processOutputsMultivalued == false) {
177179
ClassName type = processFunction.resultDataType(false);
178-
vectorize = type.simpleName().startsWith("BytesRef") == false;
180+
vectorize = type.simpleName().startsWith("BytesRef") == false && processFunction.warnExceptions.isEmpty();
179181
}
180182

181183
TypeName builderType = vectorize ? vectorFixedBuilderType(elementType(resultDataType)) : builderType(resultDataType);
@@ -192,78 +194,119 @@ private MethodSpec realEval(boolean blockStyle) {
192194
});
193195

194196
processFunction.args.stream().forEach(a -> a.createScratch(builder));
195-
196-
builder.beginControlFlow("position: for (int p = 0; p < positionCount; p++)");
197-
{
198-
if (blockStyle) {
199-
if (processOutputsMultivalued == false) {
200-
processFunction.args.stream().forEach(a -> a.skipNull(builder));
201-
} else {
202-
builder.addStatement("boolean allBlocksAreNulls = true");
203-
// allow block type inputs to be null
204-
processFunction.args.stream().forEach(a -> {
205-
if (a instanceof StandardProcessFunctionArg as) {
206-
as.skipNull(builder);
207-
} else if (a instanceof BlockProcessFunctionArg ab) {
208-
builder.beginControlFlow("if (!$N.isNull(p))", ab.paramName(blockStyle));
209-
{
210-
builder.addStatement("allBlocksAreNulls = false");
197+
if (vectorize) {
198+
realEvalWithVectorizedStyle(builder, resultDataType);
199+
} else {
200+
builder.beginControlFlow("position: for (int p = 0; p < positionCount; p++)");
201+
{
202+
if (blockStyle) {
203+
if (processOutputsMultivalued == false) {
204+
processFunction.args.stream().forEach(a -> a.skipNull(builder));
205+
} else {
206+
builder.addStatement("boolean allBlocksAreNulls = true");
207+
// allow block type inputs to be null
208+
processFunction.args.stream().forEach(a -> {
209+
if (a instanceof StandardProcessFunctionArg as) {
210+
as.skipNull(builder);
211+
} else if (a instanceof BlockProcessFunctionArg ab) {
212+
builder.beginControlFlow("if (!$N.isNull(p))", ab.paramName(blockStyle));
213+
{
214+
builder.addStatement("allBlocksAreNulls = false");
215+
}
216+
builder.endControlFlow();
211217
}
212-
builder.endControlFlow();
213-
}
214-
});
218+
});
215219

216-
builder.beginControlFlow("if (allBlocksAreNulls)");
217-
{
218-
builder.addStatement("result.appendNull()");
219-
builder.addStatement("continue position");
220+
builder.beginControlFlow("if (allBlocksAreNulls)");
221+
{
222+
builder.addStatement("result.appendNull()");
223+
builder.addStatement("continue position");
224+
}
225+
builder.endControlFlow();
220226
}
227+
}
228+
processFunction.args.stream().forEach(a -> a.unpackValues(builder, blockStyle));
229+
230+
StringBuilder pattern = new StringBuilder();
231+
List<Object> args = new ArrayList<>();
232+
pattern.append(processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(");
233+
args.add(declarationType);
234+
args.add(processFunction.function.getSimpleName());
235+
processFunction.args.stream().forEach(a -> {
236+
if (args.size() > 2) {
237+
pattern.append(", ");
238+
}
239+
a.buildInvocation(pattern, args, blockStyle);
240+
});
241+
pattern.append(")");
242+
String builtPattern;
243+
if (processFunction.builderArg == null) {
244+
builtPattern = "result.$L(" + pattern + ")";
245+
args.add(0, appendMethod(resultDataType));
246+
} else {
247+
builtPattern = pattern.toString();
248+
}
249+
if (processFunction.warnExceptions.isEmpty() == false) {
250+
builder.beginControlFlow("try");
251+
}
252+
253+
builder.addStatement("driverContext.maybeCheckForEarlyTermination(" + estimateCost + ")");
254+
builder.addStatement(builtPattern, args.toArray());
255+
256+
if (processFunction.warnExceptions.isEmpty() == false) {
257+
String catchPattern = "catch ("
258+
+ processFunction.warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | "))
259+
+ " e)";
260+
builder.nextControlFlow(catchPattern, processFunction.warnExceptions.stream().map(m -> TypeName.get(m)).toArray());
261+
builder.addStatement("warnings().registerException(e)");
262+
builder.addStatement("result.appendNull()");
221263
builder.endControlFlow();
222264
}
223265
}
224-
processFunction.args.stream().forEach(a -> a.unpackValues(builder, blockStyle));
266+
builder.endControlFlow();
267+
}
268+
builder.addStatement("return result.build()");
269+
}
270+
builder.endControlFlow();
225271

272+
return builder.build();
273+
}
274+
275+
private void realEvalWithVectorizedStyle(MethodSpec.Builder builder, ClassName resultDataType) {
276+
// generate the tight loop to allow vectorization
277+
builder.addStatement("final int maxBatchSize = DriverContext.estimateBatchSizeForEarlyTermination(" + estimateCost + ")");
278+
builder.beginControlFlow("for (int start = 0; start < positionCount; )");
279+
{
280+
builder.addStatement("int end = start + Math.min(positionCount - start, maxBatchSize)");
281+
builder.addStatement("driverContext.checkForEarlyTermination()");
282+
builder.beginControlFlow("for (int p = start; p < end; p++)");
283+
{
284+
processFunction.args.forEach(a -> a.unpackValues(builder, false));
226285
StringBuilder pattern = new StringBuilder();
227286
List<Object> args = new ArrayList<>();
228-
pattern.append(processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(");
287+
pattern.append("$T.$N(");
229288
args.add(declarationType);
230289
args.add(processFunction.function.getSimpleName());
231-
processFunction.args.stream().forEach(a -> {
290+
processFunction.args.forEach(a -> {
232291
if (args.size() > 2) {
233292
pattern.append(", ");
234293
}
235-
a.buildInvocation(pattern, args, blockStyle);
294+
a.buildInvocation(pattern, args, false);
236295
});
237296
pattern.append(")");
238297
String builtPattern;
239298
if (processFunction.builderArg == null) {
240-
builtPattern = vectorize ? "result.$L(p, " + pattern + ")" : "result.$L(" + pattern + ")";
299+
builtPattern = "result.$L(p, " + pattern + ")";
241300
args.add(0, appendMethod(resultDataType));
242301
} else {
243302
builtPattern = pattern.toString();
244303
}
245-
if (processFunction.warnExceptions.isEmpty() == false) {
246-
builder.beginControlFlow("try");
247-
}
248-
249304
builder.addStatement(builtPattern, args.toArray());
250-
251-
if (processFunction.warnExceptions.isEmpty() == false) {
252-
String catchPattern = "catch ("
253-
+ processFunction.warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | "))
254-
+ " e)";
255-
builder.nextControlFlow(catchPattern, processFunction.warnExceptions.stream().map(m -> TypeName.get(m)).toArray());
256-
builder.addStatement("warnings().registerException(e)");
257-
builder.addStatement("result.appendNull()");
258-
builder.endControlFlow();
259-
}
260305
}
261306
builder.endControlFlow();
262-
builder.addStatement("return result.build()");
307+
builder.addStatement("start = end");
263308
}
264309
builder.endControlFlow();
265-
266-
return builder.build();
267310
}
268311

269312
private static void skipNull(MethodSpec.Builder builder, String value) {

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorProcessor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
8282
env.getTypeUtils(),
8383
(ExecutableElement) evaluatorMethod,
8484
evaluatorAnn.extraName(),
85+
evaluatorAnn.estimateCost(),
8586
warnExceptionsTypes
8687
).sourceFile(),
8788
env

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ public Driver(
139139
DriverSleeps.empty()
140140
)
141141
);
142+
driverContext.initializeEarlyTerminationChecker(() -> {
143+
ensureNotCancelled();
144+
checkForEarlyTermination();
145+
});
142146
}
143147

144148
/**
@@ -186,7 +190,13 @@ SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplie
186190
long nextStatus = startTime + statusNanos;
187191
int iter = 0;
188192
while (true) {
189-
IsBlockedResult isBlocked = runSingleLoopIteration();
193+
final IsBlockedResult isBlocked;
194+
try {
195+
isBlocked = runSingleLoopIteration();
196+
} catch (DriverEarlyTerminationException ignored) {
197+
closeEarlyFinishedOperators();
198+
continue;
199+
}
190200
iter++;
191201
if (isBlocked.listener().isDone() == false) {
192202
updateStatus(nowSupplier.getAsLong() - startTime, iter, DriverStatus.Status.ASYNC, isBlocked.reason());
@@ -273,6 +283,20 @@ private IsBlockedResult runSingleLoopIteration() {
273283
}
274284
}
275285

286+
closeEarlyFinishedOperators();
287+
288+
if (movedPage == false) {
289+
return oneOf(
290+
activeOperators.stream()
291+
.map(Operator::isBlocked)
292+
.filter(laf -> laf.listener().isDone() == false)
293+
.collect(Collectors.toList())
294+
);
295+
}
296+
return Operator.NOT_BLOCKED;
297+
}
298+
299+
private void closeEarlyFinishedOperators() {
276300
for (int index = activeOperators.size() - 1; index >= 0; index--) {
277301
if (activeOperators.get(index).isFinished()) {
278302
/*
@@ -298,16 +322,6 @@ private IsBlockedResult runSingleLoopIteration() {
298322
break;
299323
}
300324
}
301-
302-
if (movedPage == false) {
303-
return oneOf(
304-
activeOperators.stream()
305-
.map(Operator::isBlocked)
306-
.filter(laf -> laf.listener().isDone() == false)
307-
.collect(Collectors.toList())
308-
);
309-
}
310-
return Operator.NOT_BLOCKED;
311325
}
312326

313327
public void cancel(String reason) {
@@ -332,6 +346,22 @@ private void ensureNotCancelled() {
332346
}
333347
}
334348

349+
private static class DriverEarlyTerminationException extends RuntimeException {
350+
351+
}
352+
353+
private void checkForEarlyTermination() throws DriverEarlyTerminationException {
354+
// If the last operation is finished, then we can discard all operations in the driver
355+
if (activeOperators.size() >= 2 && activeOperators.getLast().isFinished()) {
356+
for (int i = 0; i < activeOperators.size() - 1; i++) {
357+
Operator op = activeOperators.get(i);
358+
if (op.isFinished() == false) {
359+
throw new DriverEarlyTerminationException();
360+
}
361+
}
362+
}
363+
}
364+
335365
public static void start(
336366
ThreadContext threadContext,
337367
Executor executor,

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ public class DriverContext {
6060

6161
private final WarningsMode warningsMode;
6262

63+
private static final Runnable NO_OP = () -> {};
64+
65+
public static final int CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD = 2048;
66+
private Runnable earlyTerminationChecker = NO_OP;
67+
private int accumulatedCostForEarlyTermination;
68+
6369
public DriverContext(BigArrays bigArrays, BlockFactory blockFactory) {
6470
this(bigArrays, blockFactory, WarningsMode.COLLECT);
6571
}
@@ -175,6 +181,44 @@ public void removeAsyncAction() {
175181
asyncActions.removeInstance();
176182
}
177183

184+
/**
185+
* Accumulates the early termination cost and runs the early termination check if the accumulated cost passes the threshold.
186+
*/
187+
public void maybeCheckForEarlyTermination(int estimateCost) {
188+
accumulatedCostForEarlyTermination += estimateCost;
189+
if (accumulatedCostForEarlyTermination >= CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD) {
190+
checkForEarlyTermination();
191+
}
192+
}
193+
194+
/**
195+
* Checks if the Driver associated with this DriverContext has been cancelled or early terminated.
196+
*/
197+
public void checkForEarlyTermination() {
198+
assert earlyTerminationChecker != NO_OP : "cancellation or early termination checker not initialized";
199+
accumulatedCostForEarlyTermination = 0;
200+
earlyTerminationChecker.run();
201+
}
202+
203+
/**
204+
* Initializes the early termination or cancellation checker for this DriverContext.
205+
* This method should be called when associating this DriverContext with a driver.
206+
*/
207+
public void initializeEarlyTerminationChecker(Runnable checker) {
208+
assert earlyTerminationChecker == NO_OP : "cancellation or early termination checker already initialized";
209+
this.earlyTerminationChecker = checker;
210+
}
211+
212+
/**
213+
* Returns the estimated batch size for early termination based on the given estimate cost for each item.
214+
*/
215+
public static int estimateBatchSizeForEarlyTermination(int estimateCost) {
216+
if (estimateCost <= 0) {
217+
return CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD;
218+
}
219+
return Math.max(CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD / estimateCost, 1);
220+
}
221+
178222
/**
179223
* Evaluators should use this function to decide their warning behavior.
180224
* @return an appropriate {@link WarningsMode}

x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/evaluator/predicate/operator/logical/NotEvaluator.java

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)