Skip to content

Commit 6d21d28

Browse files
committed
Check for early termination
1 parent b34e278 commit 6d21d28

File tree

234 files changed

+2898
-298
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

234 files changed

+2898
-298
lines changed

x-pack/plugin/esql/compute/ann/src/main/java/org/elasticsearch/compute/ann/Evaluator.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,10 @@
4242
* into a warning and turn into a null value.
4343
*/
4444
Class<? extends Exception>[] warnExceptions() default {};
45+
46+
/**
47+
* The estimated cost of evaluating one row of this evaluator. The Driver periodically checks for cancellation or early termination
48+
* after it has accumulated the cost over the threshold (2048). Increase this estimate for expensive evaluators.
49+
*/
50+
int estimateCost() default 1;
4551
}

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorImplementer.java

Lines changed: 99 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -61,18 +61,23 @@ public class EvaluatorImplementer {
6161
private final TypeElement declarationType;
6262
private final ProcessFunction processFunction;
6363
private final ClassName implementation;
64+
private final int estimateCost;
6465
private final boolean processOutputsMultivalued;
6566

6667
public EvaluatorImplementer(
6768
Elements elements,
6869
javax.lang.model.util.Types types,
6970
ExecutableElement processFunction,
7071
String extraName,
72+
int estimateCost,
7173
List<TypeMirror> warnExceptions
7274
) {
7375
this.declarationType = (TypeElement) processFunction.getEnclosingElement();
7476
this.processFunction = new ProcessFunction(elements, types, processFunction, warnExceptions);
75-
77+
this.estimateCost = estimateCost;
78+
if (estimateCost <= 0) {
79+
throw new IllegalArgumentException("estimateCost must be at least 1; got " + estimateCost);
80+
}
7681
this.implementation = ClassName.get(
7782
elements.getPackageOf(declarationType).toString(),
7883
declarationType.getSimpleName() + extraName + "Evaluator"
@@ -175,7 +180,7 @@ private MethodSpec realEval(boolean blockStyle) {
175180
boolean vectorize = false;
176181
if (blockStyle == false && processFunction.warnExceptions.isEmpty() && processOutputsMultivalued == false) {
177182
ClassName type = processFunction.resultDataType(false);
178-
vectorize = type.simpleName().startsWith("BytesRef") == false;
183+
vectorize = type.simpleName().startsWith("BytesRef") == false && processFunction.warnExceptions.isEmpty();
179184
}
180185

181186
TypeName builderType = vectorize ? vectorFixedBuilderType(elementType(resultDataType)) : builderType(resultDataType);
@@ -192,78 +197,125 @@ private MethodSpec realEval(boolean blockStyle) {
192197
});
193198

194199
processFunction.args.stream().forEach(a -> a.createScratch(builder));
195-
196-
builder.beginControlFlow("position: for (int p = 0; p < positionCount; p++)");
197-
{
198-
if (blockStyle) {
199-
if (processOutputsMultivalued == false) {
200-
processFunction.args.stream().forEach(a -> a.skipNull(builder));
201-
} else {
202-
builder.addStatement("boolean allBlocksAreNulls = true");
203-
// allow block type inputs to be null
204-
processFunction.args.stream().forEach(a -> {
205-
if (a instanceof StandardProcessFunctionArg as) {
206-
as.skipNull(builder);
207-
} else if (a instanceof BlockProcessFunctionArg ab) {
208-
builder.beginControlFlow("if (!$N.isNull(p))", ab.paramName(blockStyle));
209-
{
210-
builder.addStatement("allBlocksAreNulls = false");
200+
if (vectorize) {
201+
realEvalWithVectorizedStyle(builder, resultDataType);
202+
} else {
203+
builder.addStatement("int accumulatedCost = 0");
204+
builder.beginControlFlow("position: for (int p = 0; p < positionCount; p++)");
205+
{
206+
if (blockStyle) {
207+
if (processOutputsMultivalued == false) {
208+
processFunction.args.stream().forEach(a -> a.skipNull(builder));
209+
} else {
210+
builder.addStatement("boolean allBlocksAreNulls = true");
211+
// allow block type inputs to be null
212+
processFunction.args.stream().forEach(a -> {
213+
if (a instanceof StandardProcessFunctionArg as) {
214+
as.skipNull(builder);
215+
} else if (a instanceof BlockProcessFunctionArg ab) {
216+
builder.beginControlFlow("if (!$N.isNull(p))", ab.paramName(blockStyle));
217+
{
218+
builder.addStatement("allBlocksAreNulls = false");
219+
}
220+
builder.endControlFlow();
211221
}
212-
builder.endControlFlow();
213-
}
214-
});
222+
});
215223

216-
builder.beginControlFlow("if (allBlocksAreNulls)");
217-
{
218-
builder.addStatement("result.appendNull()");
219-
builder.addStatement("continue position");
224+
builder.beginControlFlow("if (allBlocksAreNulls)");
225+
{
226+
builder.addStatement("result.appendNull()");
227+
builder.addStatement("continue position");
228+
}
229+
builder.endControlFlow();
220230
}
231+
}
232+
processFunction.args.stream().forEach(a -> a.unpackValues(builder, blockStyle));
233+
234+
StringBuilder pattern = new StringBuilder();
235+
List<Object> args = new ArrayList<>();
236+
pattern.append(processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(");
237+
args.add(declarationType);
238+
args.add(processFunction.function.getSimpleName());
239+
processFunction.args.stream().forEach(a -> {
240+
if (args.size() > 2) {
241+
pattern.append(", ");
242+
}
243+
a.buildInvocation(pattern, args, blockStyle);
244+
});
245+
pattern.append(")");
246+
String builtPattern;
247+
if (processFunction.builderArg == null) {
248+
builtPattern = "result.$L(" + pattern + ")";
249+
args.add(0, appendMethod(resultDataType));
250+
} else {
251+
builtPattern = pattern.toString();
252+
}
253+
if (processFunction.warnExceptions.isEmpty() == false) {
254+
builder.beginControlFlow("try");
255+
}
256+
builder.addStatement("accumulatedCost += " + estimateCost);
257+
builder.beginControlFlow("if (accumulatedCost >= DriverContext.CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD)");
258+
{
259+
builder.addStatement("accumulatedCost = 0");
260+
builder.addStatement("driverContext.checkForEarlyTermination()");
261+
}
262+
builder.endControlFlow();
263+
builder.addStatement(builtPattern, args.toArray());
264+
265+
if (processFunction.warnExceptions.isEmpty() == false) {
266+
String catchPattern = "catch ("
267+
+ processFunction.warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | "))
268+
+ " e)";
269+
builder.nextControlFlow(catchPattern, processFunction.warnExceptions.stream().map(m -> TypeName.get(m)).toArray());
270+
builder.addStatement("warnings().registerException(e)");
271+
builder.addStatement("result.appendNull()");
221272
builder.endControlFlow();
222273
}
223274
}
224-
processFunction.args.stream().forEach(a -> a.unpackValues(builder, blockStyle));
275+
builder.endControlFlow();
276+
}
277+
builder.addStatement("return result.build()");
278+
}
279+
builder.endControlFlow();
225280

281+
return builder.build();
282+
}
283+
284+
private void realEvalWithVectorizedStyle(MethodSpec.Builder builder, ClassName resultDataType) {
285+
builder.addComment("generate a tight loop to allow vectorization");
286+
builder.addStatement("int maxBatchSize = Math.max(DriverContext.CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD / $L, 1)", estimateCost);
287+
builder.beginControlFlow("for (int start = 0; start < positionCount; )");
288+
{
289+
builder.addStatement("int end = start + Math.min(positionCount - start, maxBatchSize)");
290+
builder.addStatement("driverContext.checkForEarlyTermination()");
291+
builder.beginControlFlow("for (int p = start; p < end; p++)");
292+
{
293+
processFunction.args.forEach(a -> a.unpackValues(builder, false));
226294
StringBuilder pattern = new StringBuilder();
227295
List<Object> args = new ArrayList<>();
228-
pattern.append(processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(");
296+
pattern.append("$T.$N(");
229297
args.add(declarationType);
230298
args.add(processFunction.function.getSimpleName());
231-
processFunction.args.stream().forEach(a -> {
299+
processFunction.args.forEach(a -> {
232300
if (args.size() > 2) {
233301
pattern.append(", ");
234302
}
235-
a.buildInvocation(pattern, args, blockStyle);
303+
a.buildInvocation(pattern, args, false);
236304
});
237305
pattern.append(")");
238306
String builtPattern;
239307
if (processFunction.builderArg == null) {
240-
builtPattern = vectorize ? "result.$L(p, " + pattern + ")" : "result.$L(" + pattern + ")";
308+
builtPattern = "result.$L(p, " + pattern + ")";
241309
args.add(0, appendMethod(resultDataType));
242310
} else {
243311
builtPattern = pattern.toString();
244312
}
245-
if (processFunction.warnExceptions.isEmpty() == false) {
246-
builder.beginControlFlow("try");
247-
}
248-
249313
builder.addStatement(builtPattern, args.toArray());
250-
251-
if (processFunction.warnExceptions.isEmpty() == false) {
252-
String catchPattern = "catch ("
253-
+ processFunction.warnExceptions.stream().map(m -> "$T").collect(Collectors.joining(" | "))
254-
+ " e)";
255-
builder.nextControlFlow(catchPattern, processFunction.warnExceptions.stream().map(m -> TypeName.get(m)).toArray());
256-
builder.addStatement("warnings().registerException(e)");
257-
builder.addStatement("result.appendNull()");
258-
builder.endControlFlow();
259-
}
260314
}
261315
builder.endControlFlow();
262-
builder.addStatement("return result.build()");
316+
builder.addStatement("start = end");
263317
}
264318
builder.endControlFlow();
265-
266-
return builder.build();
267319
}
268320

269321
private static void skipNull(MethodSpec.Builder builder, String value) {

x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/EvaluatorProcessor.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ public boolean process(Set<? extends TypeElement> set, RoundEnvironment roundEnv
8282
env.getTypeUtils(),
8383
(ExecutableElement) evaluatorMethod,
8484
evaluatorAnn.extraName(),
85+
evaluatorAnn.estimateCost(),
8586
warnExceptionsTypes
8687
).sourceFile(),
8788
env

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/Driver.java

Lines changed: 54 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ public Driver(
139139
DriverSleeps.empty()
140140
)
141141
);
142+
driverContext.initializeEarlyTerminationChecker(() -> {
143+
ensureNotCancelled();
144+
checkForEarlyTermination();
145+
});
142146
}
143147

144148
/**
@@ -186,7 +190,13 @@ SubscribableListener<Void> run(TimeValue maxTime, int maxIterations, LongSupplie
186190
long nextStatus = startTime + statusNanos;
187191
int iter = 0;
188192
while (true) {
189-
IsBlockedResult isBlocked = runSingleLoopIteration();
193+
final IsBlockedResult isBlocked;
194+
try {
195+
isBlocked = runSingleLoopIteration();
196+
} catch (DriverEarlyTerminationException ignored) {
197+
// rerun to prune early terminated operators
198+
continue;
199+
}
190200
iter++;
191201
if (isBlocked.listener().isDone() == false) {
192202
updateStatus(nowSupplier.getAsLong() - startTime, iter, DriverStatus.Status.ASYNC, isBlocked.reason());
@@ -243,6 +253,33 @@ public void abort(Exception reason, ActionListener<Void> listener) {
243253

244254
private IsBlockedResult runSingleLoopIteration() {
245255
ensureNotCancelled();
256+
257+
for (int index = activeOperators.size() - 1; index >= 0; index--) {
258+
if (activeOperators.get(index).isFinished()) {
259+
/*
260+
* Close and remove this operator and all source operators in the
261+
* most paranoid possible way. Closing operators shouldn't throw,
262+
* but if it does, this will make sure we don't try to close any
263+
* that succeed twice.
264+
*/
265+
List<Operator> finishedOperators = this.activeOperators.subList(0, index + 1);
266+
Iterator<Operator> itr = finishedOperators.iterator();
267+
while (itr.hasNext()) {
268+
Operator op = itr.next();
269+
statusOfCompletedOperators.add(new DriverStatus.OperatorStatus(op.toString(), op.status()));
270+
op.close();
271+
itr.remove();
272+
}
273+
274+
// Finish the next operator, which is now the first operator.
275+
if (activeOperators.isEmpty() == false) {
276+
Operator newRootOperator = activeOperators.get(0);
277+
newRootOperator.finish();
278+
}
279+
break;
280+
}
281+
}
282+
246283
boolean movedPage = false;
247284

248285
for (int i = 0; i < activeOperators.size() - 1; i++) {
@@ -273,32 +310,6 @@ private IsBlockedResult runSingleLoopIteration() {
273310
}
274311
}
275312

276-
for (int index = activeOperators.size() - 1; index >= 0; index--) {
277-
if (activeOperators.get(index).isFinished()) {
278-
/*
279-
* Close and remove this operator and all source operators in the
280-
* most paranoid possible way. Closing operators shouldn't throw,
281-
* but if it does, this will make sure we don't try to close any
282-
* that succeed twice.
283-
*/
284-
List<Operator> finishedOperators = this.activeOperators.subList(0, index + 1);
285-
Iterator<Operator> itr = finishedOperators.iterator();
286-
while (itr.hasNext()) {
287-
Operator op = itr.next();
288-
statusOfCompletedOperators.add(new DriverStatus.OperatorStatus(op.toString(), op.status()));
289-
op.close();
290-
itr.remove();
291-
}
292-
293-
// Finish the next operator, which is now the first operator.
294-
if (activeOperators.isEmpty() == false) {
295-
Operator newRootOperator = activeOperators.get(0);
296-
newRootOperator.finish();
297-
}
298-
break;
299-
}
300-
}
301-
302313
if (movedPage == false) {
303314
return oneOf(
304315
activeOperators.stream()
@@ -332,6 +343,22 @@ private void ensureNotCancelled() {
332343
}
333344
}
334345

346+
private static class DriverEarlyTerminationException extends RuntimeException {
347+
348+
}
349+
350+
private void checkForEarlyTermination() throws DriverEarlyTerminationException {
351+
// If the last operation is finished, then we can discard all operations in the driver
352+
if (activeOperators.size() >= 2 && activeOperators.getLast().isFinished()) {
353+
for (int i = 0; i < activeOperators.size() - 1; i++) {
354+
Operator op = activeOperators.get(i);
355+
if (op.isFinished() == false) {
356+
throw new DriverEarlyTerminationException();
357+
}
358+
}
359+
}
360+
}
361+
335362
public static void start(
336363
ThreadContext threadContext,
337364
Executor executor,

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverContext.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ public class DriverContext {
6060

6161
private final WarningsMode warningsMode;
6262

63+
public static final int CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD = 2048;
64+
private static final Runnable NO_OP = () -> {};
65+
private Runnable earlyTerminationChecker = NO_OP;
66+
6367
public DriverContext(BigArrays bigArrays, BlockFactory blockFactory) {
6468
this(bigArrays, blockFactory, WarningsMode.COLLECT);
6569
}
@@ -175,6 +179,24 @@ public void removeAsyncAction() {
175179
asyncActions.removeInstance();
176180
}
177181

182+
/**
183+
* Checks if the Driver associated with this DriverContext has been cancelled or early terminated.
184+
*/
185+
public void checkForEarlyTermination() {
186+
earlyTerminationChecker.run();
187+
}
188+
189+
/**
190+
* Initializes the early termination or cancellation checker for this DriverContext.
191+
* This method should be called when associating this DriverContext with a driver.
192+
*/
193+
public void initializeEarlyTerminationChecker(Runnable checker) {
194+
if (this.earlyTerminationChecker != NO_OP) {
195+
throw new IllegalStateException("Early termination checker already initialized");
196+
}
197+
this.earlyTerminationChecker = checker;
198+
}
199+
178200
/**
179201
* Evaluators should use this function to decide their warning behavior.
180202
* @return an appropriate {@link WarningsMode}

0 commit comments

Comments
 (0)