40
40
import static org .elasticsearch .compute .gen .Types .BOOLEAN_BLOCK ;
41
41
import static org .elasticsearch .compute .gen .Types .BYTES_REF ;
42
42
import static org .elasticsearch .compute .gen .Types .BYTES_REF_BLOCK ;
43
- import static org .elasticsearch .compute .gen .Types .BYTES_REF_VECTOR ;
44
43
import static org .elasticsearch .compute .gen .Types .DOUBLE_BLOCK ;
45
44
import static org .elasticsearch .compute .gen .Types .DRIVER_CONTEXT ;
46
45
import static org .elasticsearch .compute .gen .Types .EXPRESSION_EVALUATOR ;
@@ -62,23 +61,18 @@ public class EvaluatorImplementer {
62
61
private final TypeElement declarationType ;
63
62
private final ProcessFunction processFunction ;
64
63
private final ClassName implementation ;
65
- private final int executionCost ;
66
64
private final boolean processOutputsMultivalued ;
67
65
68
66
public EvaluatorImplementer (
69
67
Elements elements ,
70
68
javax .lang .model .util .Types types ,
71
69
ExecutableElement processFunction ,
72
70
String extraName ,
73
- int executionCost ,
74
71
List <TypeMirror > warnExceptions
75
72
) {
76
73
this .declarationType = (TypeElement ) processFunction .getEnclosingElement ();
77
74
this .processFunction = new ProcessFunction (elements , types , processFunction , warnExceptions );
78
- this .executionCost = executionCost ;
79
- if (executionCost < 0 ) {
80
- throw new IllegalArgumentException ("executionCost must be non-negative; got " + executionCost );
81
- }
75
+
82
76
this .implementation = ClassName .get (
83
77
elements .getPackageOf (declarationType ).toString (),
84
78
declarationType .getSimpleName () + extraName + "Evaluator"
@@ -198,132 +192,78 @@ private MethodSpec realEval(boolean blockStyle) {
198
192
});
199
193
200
194
processFunction .args .stream ().forEach (a -> a .createScratch (builder ));
201
- if (vectorize ) {
202
- realEvalWithVectorizedStyle (builder , resultDataType );
203
- } else {
204
- if (executionCost > 0 ) {
205
- builder .addStatement ("int accumulatedCost = 0" );
206
- }
207
- builder .beginControlFlow ("position: for (int p = 0; p < positionCount; p++)" );
208
- {
209
- if (blockStyle ) {
210
- if (processOutputsMultivalued == false ) {
211
- processFunction .args .stream ().forEach (a -> a .skipNull (builder ));
212
- } else {
213
- builder .addStatement ("boolean allBlocksAreNulls = true" );
214
- // allow block type inputs to be null
215
- processFunction .args .stream ().forEach (a -> {
216
- if (a instanceof StandardProcessFunctionArg as ) {
217
- as .skipNull (builder );
218
- } else if (a instanceof BlockProcessFunctionArg ab ) {
219
- builder .beginControlFlow ("if (!$N.isNull(p))" , ab .paramName (blockStyle ));
220
- {
221
- builder .addStatement ("allBlocksAreNulls = false" );
222
- }
223
- builder .endControlFlow ();
224
- }
225
- });
226
195
227
- builder .beginControlFlow ("if (allBlocksAreNulls)" );
228
- {
229
- builder .addStatement ("result.appendNull()" );
230
- builder .addStatement ("continue position" );
196
+ builder .beginControlFlow ("position: for (int p = 0; p < positionCount; p++)" );
197
+ {
198
+ if (blockStyle ) {
199
+ if (processOutputsMultivalued == false ) {
200
+ processFunction .args .stream ().forEach (a -> a .skipNull (builder ));
201
+ } else {
202
+ builder .addStatement ("boolean allBlocksAreNulls = true" );
203
+ // allow block type inputs to be null
204
+ processFunction .args .stream ().forEach (a -> {
205
+ if (a instanceof StandardProcessFunctionArg as ) {
206
+ as .skipNull (builder );
207
+ } else if (a instanceof BlockProcessFunctionArg ab ) {
208
+ builder .beginControlFlow ("if (!$N.isNull(p))" , ab .paramName (blockStyle ));
209
+ {
210
+ builder .addStatement ("allBlocksAreNulls = false" );
211
+ }
212
+ builder .endControlFlow ();
231
213
}
232
- builder .endControlFlow ();
233
- }
234
- }
235
- processFunction .args .stream ().forEach (a -> a .unpackValues (builder , blockStyle ));
236
-
237
- StringBuilder pattern = new StringBuilder ();
238
- List <Object > args = new ArrayList <>();
239
- pattern .append (processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(" );
240
- args .add (declarationType );
241
- args .add (processFunction .function .getSimpleName ());
242
- processFunction .args .stream ().forEach (a -> {
243
- if (args .size () > 2 ) {
244
- pattern .append (", " );
214
+ });
215
+
216
+ builder .beginControlFlow ("if (allBlocksAreNulls)" );
217
+ {
218
+ builder .addStatement ("result.appendNull()" );
219
+ builder .addStatement ("continue position" );
245
220
}
246
- a .buildInvocation (pattern , args , blockStyle );
247
- });
248
- pattern .append (")" );
249
- String builtPattern ;
250
- if (processFunction .builderArg == null ) {
251
- builtPattern = "result.$L(" + pattern + ")" ;
252
- args .add (0 , appendMethod (resultDataType ));
253
- } else {
254
- builtPattern = pattern .toString ();
255
- }
256
- if (processFunction .warnExceptions .isEmpty () == false ) {
257
- builder .beginControlFlow ("try" );
258
- }
259
- if (executionCost > 0 ) {
260
- addEarlyTerminationCheck (builder , executionCost );
261
- }
262
- builder .addStatement (builtPattern , args .toArray ());
263
-
264
- if (processFunction .warnExceptions .isEmpty () == false ) {
265
- String catchPattern = "catch ("
266
- + processFunction .warnExceptions .stream ().map (m -> "$T" ).collect (Collectors .joining (" | " ))
267
- + " e)" ;
268
- builder .nextControlFlow (catchPattern , processFunction .warnExceptions .stream ().map (m -> TypeName .get (m )).toArray ());
269
- builder .addStatement ("warnings().registerException(e)" );
270
- builder .addStatement ("result.appendNull()" );
271
221
builder .endControlFlow ();
272
222
}
273
223
}
274
- builder .endControlFlow ();
275
- }
276
- builder .addStatement ("return result.build()" );
277
- }
278
- builder .endControlFlow ();
224
+ processFunction .args .stream ().forEach (a -> a .unpackValues (builder , blockStyle ));
225
+
226
+ StringBuilder pattern = new StringBuilder ();
227
+ List <Object > args = new ArrayList <>();
228
+ pattern .append (processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(" );
229
+ args .add (declarationType );
230
+ args .add (processFunction .function .getSimpleName ());
231
+ processFunction .args .stream ().forEach (a -> {
232
+ if (args .size () > 2 ) {
233
+ pattern .append (", " );
234
+ }
235
+ a .buildInvocation (pattern , args , blockStyle );
236
+ });
237
+ pattern .append (")" );
238
+ String builtPattern ;
239
+ if (processFunction .builderArg == null ) {
240
+ builtPattern = vectorize ? "result.$L(p, " + pattern + ")" : "result.$L(" + pattern + ")" ;
241
+ args .add (0 , appendMethod (resultDataType ));
242
+ } else {
243
+ builtPattern = pattern .toString ();
244
+ }
245
+ if (processFunction .warnExceptions .isEmpty () == false ) {
246
+ builder .beginControlFlow ("try" );
247
+ }
279
248
280
- return builder .build ();
281
- }
249
+ builder .addStatement (builtPattern , args .toArray ());
282
250
283
- private void realEvalWithVectorizedStyle (MethodSpec .Builder builder , ClassName resultDataType ) {
284
- boolean checkEarlyTerminationPerRow = executionCost > 0
285
- && processFunction .args .stream ().anyMatch (a -> a .dataType (false ).equals (BYTES_REF_VECTOR ));
286
- if (checkEarlyTerminationPerRow ) {
287
- builder .addStatement ("int accumulatedCost = 0" );
288
- }
289
- builder .beginControlFlow ("position: for (int p = 0; p < positionCount; p++)" );
290
- {
291
- processFunction .args .forEach (a -> a .unpackValues (builder , false ));
292
- StringBuilder pattern = new StringBuilder ();
293
- List <Object > args = new ArrayList <>();
294
- pattern .append ("$T.$N(" );
295
- args .add (declarationType );
296
- args .add (processFunction .function .getSimpleName ());
297
- processFunction .args .forEach (a -> {
298
- if (args .size () > 2 ) {
299
- pattern .append (", " );
251
+ if (processFunction .warnExceptions .isEmpty () == false ) {
252
+ String catchPattern = "catch ("
253
+ + processFunction .warnExceptions .stream ().map (m -> "$T" ).collect (Collectors .joining (" | " ))
254
+ + " e)" ;
255
+ builder .nextControlFlow (catchPattern , processFunction .warnExceptions .stream ().map (m -> TypeName .get (m )).toArray ());
256
+ builder .addStatement ("warnings().registerException(e)" );
257
+ builder .addStatement ("result.appendNull()" );
258
+ builder .endControlFlow ();
300
259
}
301
- a .buildInvocation (pattern , args , false );
302
- });
303
- pattern .append (")" );
304
- String builtPattern ;
305
- if (processFunction .builderArg == null ) {
306
- builtPattern = "result.$L(p, " + pattern + ")" ;
307
- args .add (0 , appendMethod (resultDataType ));
308
- } else {
309
- builtPattern = pattern .toString ();
310
- }
311
- if (checkEarlyTerminationPerRow ) {
312
- addEarlyTerminationCheck (builder , executionCost );
313
260
}
314
- builder .addStatement (builtPattern , args .toArray ());
261
+ builder .endControlFlow ();
262
+ builder .addStatement ("return result.build()" );
315
263
}
316
264
builder .endControlFlow ();
317
- }
318
265
319
- static void addEarlyTerminationCheck (MethodSpec .Builder builder , int executionCost ) {
320
- builder .addStatement ("accumulatedCost += $L" , executionCost );
321
- builder .beginControlFlow ("if (accumulatedCost >= DriverContext.CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD)" );
322
- {
323
- builder .addStatement ("accumulatedCost = 0" );
324
- builder .addStatement ("driverContext.checkForEarlyTermination()" );
325
- }
326
- builder .endControlFlow ();
266
+ return builder .build ();
327
267
}
328
268
329
269
private static void skipNull (MethodSpec .Builder builder , String value ) {
0 commit comments