@@ -61,18 +61,23 @@ public class EvaluatorImplementer {
61
61
private final TypeElement declarationType ;
62
62
private final ProcessFunction processFunction ;
63
63
private final ClassName implementation ;
64
+ private final int estimateCost ;
64
65
private final boolean processOutputsMultivalued ;
65
66
66
67
public EvaluatorImplementer (
67
68
Elements elements ,
68
69
javax .lang .model .util .Types types ,
69
70
ExecutableElement processFunction ,
70
71
String extraName ,
72
+ int estimateCost ,
71
73
List <TypeMirror > warnExceptions
72
74
) {
73
75
this .declarationType = (TypeElement ) processFunction .getEnclosingElement ();
74
76
this .processFunction = new ProcessFunction (elements , types , processFunction , warnExceptions );
75
-
77
+ this .estimateCost = estimateCost ;
78
+ if (estimateCost <= 0 ) {
79
+ throw new IllegalArgumentException ("estimateCost must be at least 1; got " + estimateCost );
80
+ }
76
81
this .implementation = ClassName .get (
77
82
elements .getPackageOf (declarationType ).toString (),
78
83
declarationType .getSimpleName () + extraName + "Evaluator"
@@ -175,7 +180,7 @@ private MethodSpec realEval(boolean blockStyle) {
175
180
boolean vectorize = false ;
176
181
if (blockStyle == false && processFunction .warnExceptions .isEmpty () && processOutputsMultivalued == false ) {
177
182
ClassName type = processFunction .resultDataType (false );
178
- vectorize = type .simpleName ().startsWith ("BytesRef" ) == false ;
183
+ vectorize = type .simpleName ().startsWith ("BytesRef" ) == false && processFunction . warnExceptions . isEmpty () ;
179
184
}
180
185
181
186
TypeName builderType = vectorize ? vectorFixedBuilderType (elementType (resultDataType )) : builderType (resultDataType );
@@ -192,78 +197,125 @@ private MethodSpec realEval(boolean blockStyle) {
192
197
});
193
198
194
199
processFunction .args .stream ().forEach (a -> a .createScratch (builder ));
195
-
196
- builder .beginControlFlow ("position: for (int p = 0; p < positionCount; p++)" );
197
- {
198
- if (blockStyle ) {
199
- if (processOutputsMultivalued == false ) {
200
- processFunction .args .stream ().forEach (a -> a .skipNull (builder ));
201
- } else {
202
- builder .addStatement ("boolean allBlocksAreNulls = true" );
203
- // allow block type inputs to be null
204
- processFunction .args .stream ().forEach (a -> {
205
- if (a instanceof StandardProcessFunctionArg as ) {
206
- as .skipNull (builder );
207
- } else if (a instanceof BlockProcessFunctionArg ab ) {
208
- builder .beginControlFlow ("if (!$N.isNull(p))" , ab .paramName (blockStyle ));
209
- {
210
- builder .addStatement ("allBlocksAreNulls = false" );
200
+ if (vectorize ) {
201
+ realEvalWithVectorizedStyle (builder , resultDataType );
202
+ } else {
203
+ builder .addStatement ("int accumulatedCost = 0" );
204
+ builder .beginControlFlow ("position: for (int p = 0; p < positionCount; p++)" );
205
+ {
206
+ if (blockStyle ) {
207
+ if (processOutputsMultivalued == false ) {
208
+ processFunction .args .stream ().forEach (a -> a .skipNull (builder ));
209
+ } else {
210
+ builder .addStatement ("boolean allBlocksAreNulls = true" );
211
+ // allow block type inputs to be null
212
+ processFunction .args .stream ().forEach (a -> {
213
+ if (a instanceof StandardProcessFunctionArg as ) {
214
+ as .skipNull (builder );
215
+ } else if (a instanceof BlockProcessFunctionArg ab ) {
216
+ builder .beginControlFlow ("if (!$N.isNull(p))" , ab .paramName (blockStyle ));
217
+ {
218
+ builder .addStatement ("allBlocksAreNulls = false" );
219
+ }
220
+ builder .endControlFlow ();
211
221
}
212
- builder .endControlFlow ();
213
- }
214
- });
222
+ });
215
223
216
- builder .beginControlFlow ("if (allBlocksAreNulls)" );
217
- {
218
- builder .addStatement ("result.appendNull()" );
219
- builder .addStatement ("continue position" );
224
+ builder .beginControlFlow ("if (allBlocksAreNulls)" );
225
+ {
226
+ builder .addStatement ("result.appendNull()" );
227
+ builder .addStatement ("continue position" );
228
+ }
229
+ builder .endControlFlow ();
220
230
}
231
+ }
232
+ processFunction .args .stream ().forEach (a -> a .unpackValues (builder , blockStyle ));
233
+
234
+ StringBuilder pattern = new StringBuilder ();
235
+ List <Object > args = new ArrayList <>();
236
+ pattern .append (processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(" );
237
+ args .add (declarationType );
238
+ args .add (processFunction .function .getSimpleName ());
239
+ processFunction .args .stream ().forEach (a -> {
240
+ if (args .size () > 2 ) {
241
+ pattern .append (", " );
242
+ }
243
+ a .buildInvocation (pattern , args , blockStyle );
244
+ });
245
+ pattern .append (")" );
246
+ String builtPattern ;
247
+ if (processFunction .builderArg == null ) {
248
+ builtPattern = "result.$L(" + pattern + ")" ;
249
+ args .add (0 , appendMethod (resultDataType ));
250
+ } else {
251
+ builtPattern = pattern .toString ();
252
+ }
253
+ if (processFunction .warnExceptions .isEmpty () == false ) {
254
+ builder .beginControlFlow ("try" );
255
+ }
256
+ builder .addStatement ("accumulatedCost += " + estimateCost );
257
+ builder .beginControlFlow ("if (accumulatedCost >= DriverContext.CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD)" );
258
+ {
259
+ builder .addStatement ("accumulatedCost = 0" );
260
+ builder .addStatement ("driverContext.checkForEarlyTermination()" );
261
+ }
262
+ builder .endControlFlow ();
263
+ builder .addStatement (builtPattern , args .toArray ());
264
+
265
+ if (processFunction .warnExceptions .isEmpty () == false ) {
266
+ String catchPattern = "catch ("
267
+ + processFunction .warnExceptions .stream ().map (m -> "$T" ).collect (Collectors .joining (" | " ))
268
+ + " e)" ;
269
+ builder .nextControlFlow (catchPattern , processFunction .warnExceptions .stream ().map (m -> TypeName .get (m )).toArray ());
270
+ builder .addStatement ("warnings().registerException(e)" );
271
+ builder .addStatement ("result.appendNull()" );
221
272
builder .endControlFlow ();
222
273
}
223
274
}
224
- processFunction .args .stream ().forEach (a -> a .unpackValues (builder , blockStyle ));
275
+ builder .endControlFlow ();
276
+ }
277
+ builder .addStatement ("return result.build()" );
278
+ }
279
+ builder .endControlFlow ();
225
280
281
+ return builder .build ();
282
+ }
283
+
284
+ private void realEvalWithVectorizedStyle (MethodSpec .Builder builder , ClassName resultDataType ) {
285
+ builder .addComment ("generate a tight loop to allow vectorization" );
286
+ builder .addStatement ("int maxBatchSize = Math.max(DriverContext.CHECK_FOR_EARLY_TERMINATION_COST_THRESHOLD / $L, 1)" , estimateCost );
287
+ builder .beginControlFlow ("for (int start = 0; start < positionCount; )" );
288
+ {
289
+ builder .addStatement ("int end = start + Math.min(positionCount - start, maxBatchSize)" );
290
+ builder .addStatement ("driverContext.checkForEarlyTermination()" );
291
+ builder .beginControlFlow ("for (int p = start; p < end; p++)" );
292
+ {
293
+ processFunction .args .forEach (a -> a .unpackValues (builder , false ));
226
294
StringBuilder pattern = new StringBuilder ();
227
295
List <Object > args = new ArrayList <>();
228
- pattern .append (processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(" );
296
+ pattern .append ("$T.$N(" );
229
297
args .add (declarationType );
230
298
args .add (processFunction .function .getSimpleName ());
231
- processFunction .args .stream (). forEach (a -> {
299
+ processFunction .args .forEach (a -> {
232
300
if (args .size () > 2 ) {
233
301
pattern .append (", " );
234
302
}
235
- a .buildInvocation (pattern , args , blockStyle );
303
+ a .buildInvocation (pattern , args , false );
236
304
});
237
305
pattern .append (")" );
238
306
String builtPattern ;
239
307
if (processFunction .builderArg == null ) {
240
- builtPattern = vectorize ? "result.$L(p, " + pattern + ")" : "result.$L( " + pattern + ")" ;
308
+ builtPattern = "result.$L(p, " + pattern + ")" ;
241
309
args .add (0 , appendMethod (resultDataType ));
242
310
} else {
243
311
builtPattern = pattern .toString ();
244
312
}
245
- if (processFunction .warnExceptions .isEmpty () == false ) {
246
- builder .beginControlFlow ("try" );
247
- }
248
-
249
313
builder .addStatement (builtPattern , args .toArray ());
250
-
251
- if (processFunction .warnExceptions .isEmpty () == false ) {
252
- String catchPattern = "catch ("
253
- + processFunction .warnExceptions .stream ().map (m -> "$T" ).collect (Collectors .joining (" | " ))
254
- + " e)" ;
255
- builder .nextControlFlow (catchPattern , processFunction .warnExceptions .stream ().map (m -> TypeName .get (m )).toArray ());
256
- builder .addStatement ("warnings().registerException(e)" );
257
- builder .addStatement ("result.appendNull()" );
258
- builder .endControlFlow ();
259
- }
260
314
}
261
315
builder .endControlFlow ();
262
- builder .addStatement ("return result.build() " );
316
+ builder .addStatement ("start = end " );
263
317
}
264
318
builder .endControlFlow ();
265
-
266
- return builder .build ();
267
319
}
268
320
269
321
private static void skipNull (MethodSpec .Builder builder , String value ) {
0 commit comments