@@ -61,18 +61,20 @@ public class EvaluatorImplementer {
61
61
private final TypeElement declarationType ;
62
62
private final ProcessFunction processFunction ;
63
63
private final ClassName implementation ;
64
+ private final int estimateCost ;
64
65
private final boolean processOutputsMultivalued ;
65
66
66
67
public EvaluatorImplementer (
67
68
Elements elements ,
68
69
javax .lang .model .util .Types types ,
69
70
ExecutableElement processFunction ,
70
71
String extraName ,
72
+ int estimateCost ,
71
73
List <TypeMirror > warnExceptions
72
74
) {
73
75
this .declarationType = (TypeElement ) processFunction .getEnclosingElement ();
74
76
this .processFunction = new ProcessFunction (elements , types , processFunction , warnExceptions );
75
-
77
+ this . estimateCost = estimateCost ;
76
78
this .implementation = ClassName .get (
77
79
elements .getPackageOf (declarationType ).toString (),
78
80
declarationType .getSimpleName () + extraName + "Evaluator"
@@ -175,7 +177,7 @@ private MethodSpec realEval(boolean blockStyle) {
175
177
boolean vectorize = false ;
176
178
if (blockStyle == false && processFunction .warnExceptions .isEmpty () && processOutputsMultivalued == false ) {
177
179
ClassName type = processFunction .resultDataType (false );
178
- vectorize = type .simpleName ().startsWith ("BytesRef" ) == false ;
180
+ vectorize = type .simpleName ().startsWith ("BytesRef" ) == false && processFunction . warnExceptions . isEmpty () ;
179
181
}
180
182
181
183
TypeName builderType = vectorize ? vectorFixedBuilderType (elementType (resultDataType )) : builderType (resultDataType );
@@ -192,78 +194,119 @@ private MethodSpec realEval(boolean blockStyle) {
192
194
});
193
195
194
196
processFunction .args .stream ().forEach (a -> a .createScratch (builder ));
195
-
196
- builder .beginControlFlow ("position: for (int p = 0; p < positionCount; p++)" );
197
- {
198
- if (blockStyle ) {
199
- if (processOutputsMultivalued == false ) {
200
- processFunction .args .stream ().forEach (a -> a .skipNull (builder ));
201
- } else {
202
- builder .addStatement ("boolean allBlocksAreNulls = true" );
203
- // allow block type inputs to be null
204
- processFunction .args .stream ().forEach (a -> {
205
- if (a instanceof StandardProcessFunctionArg as ) {
206
- as .skipNull (builder );
207
- } else if (a instanceof BlockProcessFunctionArg ab ) {
208
- builder .beginControlFlow ("if (!$N.isNull(p))" , ab .paramName (blockStyle ));
209
- {
210
- builder .addStatement ("allBlocksAreNulls = false" );
197
+ if (vectorize ) {
198
+ realEvalWithVectorizedStyle (builder , resultDataType );
199
+ } else {
200
+ builder .beginControlFlow ("position: for (int p = 0; p < positionCount; p++)" );
201
+ {
202
+ if (blockStyle ) {
203
+ if (processOutputsMultivalued == false ) {
204
+ processFunction .args .stream ().forEach (a -> a .skipNull (builder ));
205
+ } else {
206
+ builder .addStatement ("boolean allBlocksAreNulls = true" );
207
+ // allow block type inputs to be null
208
+ processFunction .args .stream ().forEach (a -> {
209
+ if (a instanceof StandardProcessFunctionArg as ) {
210
+ as .skipNull (builder );
211
+ } else if (a instanceof BlockProcessFunctionArg ab ) {
212
+ builder .beginControlFlow ("if (!$N.isNull(p))" , ab .paramName (blockStyle ));
213
+ {
214
+ builder .addStatement ("allBlocksAreNulls = false" );
215
+ }
216
+ builder .endControlFlow ();
211
217
}
212
- builder .endControlFlow ();
213
- }
214
- });
218
+ });
215
219
216
- builder .beginControlFlow ("if (allBlocksAreNulls)" );
217
- {
218
- builder .addStatement ("result.appendNull()" );
219
- builder .addStatement ("continue position" );
220
+ builder .beginControlFlow ("if (allBlocksAreNulls)" );
221
+ {
222
+ builder .addStatement ("result.appendNull()" );
223
+ builder .addStatement ("continue position" );
224
+ }
225
+ builder .endControlFlow ();
220
226
}
227
+ }
228
+ processFunction .args .stream ().forEach (a -> a .unpackValues (builder , blockStyle ));
229
+
230
+ StringBuilder pattern = new StringBuilder ();
231
+ List <Object > args = new ArrayList <>();
232
+ pattern .append (processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(" );
233
+ args .add (declarationType );
234
+ args .add (processFunction .function .getSimpleName ());
235
+ processFunction .args .stream ().forEach (a -> {
236
+ if (args .size () > 2 ) {
237
+ pattern .append (", " );
238
+ }
239
+ a .buildInvocation (pattern , args , blockStyle );
240
+ });
241
+ pattern .append (")" );
242
+ String builtPattern ;
243
+ if (processFunction .builderArg == null ) {
244
+ builtPattern = "result.$L(" + pattern + ")" ;
245
+ args .add (0 , appendMethod (resultDataType ));
246
+ } else {
247
+ builtPattern = pattern .toString ();
248
+ }
249
+ if (processFunction .warnExceptions .isEmpty () == false ) {
250
+ builder .beginControlFlow ("try" );
251
+ }
252
+
253
+ builder .addStatement ("driverContext.maybeCheckForEarlyTermination(" + estimateCost + ")" );
254
+ builder .addStatement (builtPattern , args .toArray ());
255
+
256
+ if (processFunction .warnExceptions .isEmpty () == false ) {
257
+ String catchPattern = "catch ("
258
+ + processFunction .warnExceptions .stream ().map (m -> "$T" ).collect (Collectors .joining (" | " ))
259
+ + " e)" ;
260
+ builder .nextControlFlow (catchPattern , processFunction .warnExceptions .stream ().map (m -> TypeName .get (m )).toArray ());
261
+ builder .addStatement ("warnings().registerException(e)" );
262
+ builder .addStatement ("result.appendNull()" );
221
263
builder .endControlFlow ();
222
264
}
223
265
}
224
- processFunction .args .stream ().forEach (a -> a .unpackValues (builder , blockStyle ));
266
+ builder .endControlFlow ();
267
+ }
268
+ builder .addStatement ("return result.build()" );
269
+ }
270
+ builder .endControlFlow ();
225
271
272
+ return builder .build ();
273
+ }
274
+
275
+ private void realEvalWithVectorizedStyle (MethodSpec .Builder builder , ClassName resultDataType ) {
276
+ // generate the tight loop to allow vectorization
277
+ builder .addStatement ("final int maxBatchSize = DriverContext.estimateBatchSizeForEarlyTermination(" + estimateCost + ")" );
278
+ builder .beginControlFlow ("for (int start = 0; start < positionCount; )" );
279
+ {
280
+ builder .addStatement ("int end = start + Math.min(positionCount - start, maxBatchSize)" );
281
+ builder .addStatement ("driverContext.checkForEarlyTermination()" );
282
+ builder .beginControlFlow ("for (int p = start; p < end; p++)" );
283
+ {
284
+ processFunction .args .forEach (a -> a .unpackValues (builder , false ));
226
285
StringBuilder pattern = new StringBuilder ();
227
286
List <Object > args = new ArrayList <>();
228
- pattern .append (processOutputsMultivalued ? "$T.$N(result, p, " : "$T.$N(" );
287
+ pattern .append ("$T.$N(" );
229
288
args .add (declarationType );
230
289
args .add (processFunction .function .getSimpleName ());
231
- processFunction .args .stream (). forEach (a -> {
290
+ processFunction .args .forEach (a -> {
232
291
if (args .size () > 2 ) {
233
292
pattern .append (", " );
234
293
}
235
- a .buildInvocation (pattern , args , blockStyle );
294
+ a .buildInvocation (pattern , args , false );
236
295
});
237
296
pattern .append (")" );
238
297
String builtPattern ;
239
298
if (processFunction .builderArg == null ) {
240
- builtPattern = vectorize ? "result.$L(p, " + pattern + ")" : "result.$L( " + pattern + ")" ;
299
+ builtPattern = "result.$L(p, " + pattern + ")" ;
241
300
args .add (0 , appendMethod (resultDataType ));
242
301
} else {
243
302
builtPattern = pattern .toString ();
244
303
}
245
- if (processFunction .warnExceptions .isEmpty () == false ) {
246
- builder .beginControlFlow ("try" );
247
- }
248
-
249
304
builder .addStatement (builtPattern , args .toArray ());
250
-
251
- if (processFunction .warnExceptions .isEmpty () == false ) {
252
- String catchPattern = "catch ("
253
- + processFunction .warnExceptions .stream ().map (m -> "$T" ).collect (Collectors .joining (" | " ))
254
- + " e)" ;
255
- builder .nextControlFlow (catchPattern , processFunction .warnExceptions .stream ().map (m -> TypeName .get (m )).toArray ());
256
- builder .addStatement ("warnings().registerException(e)" );
257
- builder .addStatement ("result.appendNull()" );
258
- builder .endControlFlow ();
259
- }
260
305
}
261
306
builder .endControlFlow ();
262
- builder .addStatement ("return result.build() " );
307
+ builder .addStatement ("start = end " );
263
308
}
264
309
builder .endControlFlow ();
265
-
266
- return builder .build ();
267
310
}
268
311
269
312
private static void skipNull (MethodSpec .Builder builder , String value ) {
0 commit comments