20
20
import org .elasticsearch .xpack .sql .expression .function .aggregate .Sum ;
21
21
import org .elasticsearch .xpack .sql .expression .function .aggregate .SumOfSquares ;
22
22
import org .elasticsearch .xpack .sql .expression .function .aggregate .VarPop ;
23
+ import org .elasticsearch .xpack .sql .expression .function .scalar .Cast ;
23
24
import org .elasticsearch .xpack .sql .expression .function .scalar .datetime .DayName ;
24
25
import org .elasticsearch .xpack .sql .expression .function .scalar .datetime .DayOfMonth ;
25
26
import org .elasticsearch .xpack .sql .expression .function .scalar .datetime .DayOfWeek ;
84
85
import org .elasticsearch .xpack .sql .expression .predicate .operator .arithmetic .Mod ;
85
86
import org .elasticsearch .xpack .sql .parser .ParsingException ;
86
87
import org .elasticsearch .xpack .sql .tree .Location ;
88
+ import org .elasticsearch .xpack .sql .type .DataType ;
87
89
import org .elasticsearch .xpack .sql .util .StringUtils ;
88
90
89
91
import java .util .Arrays ;
@@ -116,14 +118,14 @@ public class FunctionRegistry {
116
118
public FunctionRegistry () {
117
119
defineDefaultFunctions ();
118
120
}
119
-
121
+
120
122
/**
121
123
* Constructor specifying alternate functions for testing.
122
124
*/
123
125
FunctionRegistry (FunctionDefinition ... functions ) {
124
126
addToMap (functions );
125
127
}
126
-
128
+
127
129
private void defineDefaultFunctions () {
128
130
// Aggregate functions
129
131
addToMap (def (Avg .class , Avg ::new ),
@@ -206,11 +208,13 @@ private void defineDefaultFunctions() {
206
208
def (Space .class , Space ::new ),
207
209
def (Substring .class , Substring ::new ),
208
210
def (UCase .class , UCase ::new ));
211
+ // DataType conversion
212
+ addToMap (def (Cast .class , Cast ::new , "CONVERT" ));
209
213
// Special
210
214
addToMap (def (Score .class , Score ::new ));
211
215
}
212
-
213
- protected void addToMap (FunctionDefinition ...functions ) {
216
+
217
+ void addToMap (FunctionDefinition ...functions ) {
214
218
// temporary map to hold [function_name/alias_name : function instance]
215
219
Map <String , FunctionDefinition > batchMap = new HashMap <>();
216
220
for (FunctionDefinition f : functions ) {
@@ -227,7 +231,7 @@ protected void addToMap(FunctionDefinition...functions) {
227
231
// sort the temporary map by key name and add it to the global map of functions
228
232
defs .putAll (batchMap .entrySet ().stream ()
229
233
.sorted (Map .Entry .comparingByKey ())
230
- .collect (Collectors .<Entry <String , FunctionDefinition >, String ,
234
+ .collect (Collectors .<Entry <String , FunctionDefinition >, String ,
231
235
FunctionDefinition , LinkedHashMap <String , FunctionDefinition >> toMap (Map .Entry ::getKey , Map .Entry ::getValue ,
232
236
(oldValue , newValue ) -> oldValue , LinkedHashMap ::new )));
233
237
}
@@ -390,7 +394,7 @@ private static FunctionDefinition def(Class<? extends Function> function, Functi
390
394
private interface FunctionBuilder {
391
395
Function build (Location location , List <Expression > children , boolean distinct , TimeZone tz );
392
396
}
393
-
397
+
394
398
@ SuppressWarnings ("overloads" ) // These are ambiguous if you aren't using ctor references but we always do
395
399
static <T extends Function > FunctionDefinition def (Class <T > function ,
396
400
ThreeParametersFunctionBuilder <T > ctorRef , String ... aliases ) {
@@ -408,11 +412,11 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
408
412
};
409
413
return def (function , builder , false , aliases );
410
414
}
411
-
415
+
412
416
interface ThreeParametersFunctionBuilder <T > {
413
417
T build (Location location , Expression source , Expression exp1 , Expression exp2 );
414
418
}
415
-
419
+
416
420
@ SuppressWarnings ("overloads" ) // These are ambiguous if you aren't using ctor references but we always do
417
421
static <T extends Function > FunctionDefinition def (Class <T > function ,
418
422
FourParametersFunctionBuilder <T > ctorRef , String ... aliases ) {
@@ -427,11 +431,29 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
427
431
};
428
432
return def (function , builder , false , aliases );
429
433
}
430
-
434
+
431
435
interface FourParametersFunctionBuilder <T > {
432
436
T build (Location location , Expression source , Expression exp1 , Expression exp2 , Expression exp3 );
433
437
}
434
438
439
+ /**
440
+ * Special method to create function definition for {@link Cast} as its
441
+ * signature is not compatible with {@link UnresolvedFunction}
442
+ *
443
+ * @return Cast function definition
444
+ */
445
+ @ SuppressWarnings ("overloads" ) // These are ambiguous if you aren't using ctor references but we always do
446
+ private static <T extends Function > FunctionDefinition def (Class <T > function ,
447
+ CastFunctionBuilder <T > ctorRef ,
448
+ String ... aliases ) {
449
+ FunctionBuilder builder = (location , children , distinct , tz ) ->
450
+ ctorRef .build (location , children .get (0 ), children .get (0 ).dataType ());
451
+ return def (function , builder , false , aliases );
452
+ }
453
+ private interface CastFunctionBuilder <T > {
454
+ T build (Location location , Expression expression , DataType dataType );
455
+ }
456
+
435
457
private static String normalize (String name ) {
436
458
// translate CamelCase to camel_case
437
459
return StringUtils .camelCaseToUnderscore (name );
0 commit comments