diff --git a/src/query/ast/src/ast/expr.rs b/src/query/ast/src/ast/expr.rs index 9d0a40a72a88d..71395ad5fae1b 100644 --- a/src/query/ast/src/ast/expr.rs +++ b/src/query/ast/src/ast/expr.rs @@ -180,6 +180,7 @@ pub enum Expr { args: Vec, params: Vec, window: Option, + lambda: Option, }, /// `CASE ... WHEN ... ELSE ...` expression Case { @@ -376,6 +377,12 @@ pub enum WindowFrameBound { Following(Option>), } +#[derive(Debug, Clone, PartialEq)] +pub struct Lambda { + pub params: Vec, + pub expr: Box, +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum BinaryOperator { Plus, @@ -882,6 +889,21 @@ impl Display for WindowSpec { } } +impl Display for Lambda { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.params.len() == 1 { + write!(f, "{}", self.params[0])?; + } else { + write!(f, "(")?; + write_comma_separated_list(f, self.params.clone())?; + write!(f, ")")?; + } + write!(f, " -> {}", self.expr)?; + + Ok(()) + } +} + impl Display for Expr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -1043,6 +1065,7 @@ impl Display for Expr { args, params, window, + lambda, .. } => { write!(f, "{name}")?; @@ -1056,6 +1079,9 @@ impl Display for Expr { write!(f, "DISTINCT ")?; } write_comma_separated_list(f, args)?; + if let Some(lambda) = lambda { + write!(f, ", {lambda}")?; + } write!(f, ")")?; if let Some(window) = window { diff --git a/src/query/ast/src/ast/format/ast_format.rs b/src/query/ast/src/ast/format/ast_format.rs index 9626e7cb52911..f4c2d8f300654 100644 --- a/src/query/ast/src/ast/format/ast_format.rs +++ b/src/query/ast/src/ast/format/ast_format.rs @@ -431,6 +431,7 @@ impl<'ast> Visitor<'ast> for AstFormatVisitor { args: &'ast [Expr], _params: &'ast [Literal], _over: &'ast Option, + _lambda: &'ast Option, ) { let mut children = Vec::with_capacity(args.len()); for arg in args.iter() { diff --git a/src/query/ast/src/parser/expr.rs b/src/query/ast/src/parser/expr.rs index 91b8755dd5d72..52f96c423274e 100644 --- a/src/query/ast/src/parser/expr.rs +++ b/src/query/ast/src/parser/expr.rs @@ -291,8 +291,9 @@ pub enum ExprElement { distinct: bool, name: Identifier, args: Vec, - window: Option, params: Vec, + window: Option, + lambda: Option, }, /// `CASE ... WHEN ... ELSE ...` expression Case { @@ -491,6 +492,7 @@ impl<'a, I: Iterator>> PrattParser for ExprP args, params, window, + lambda, } => Expr::FunctionCall { span: transform_span(elem.span.0), distinct, @@ -498,6 +500,7 @@ impl<'a, I: Iterator>> PrattParser for ExprP args, params, window, + lambda, }, ExprElement::Case { operand, @@ -830,6 +833,25 @@ pub fn expr_element(i: Input) -> IResult> { args: opt_args.unwrap_or_default(), params: vec![], window: None, + lambda: None, + }, + ); + + let function_call_with_lambda = map( + rule! { + #function_name + ~ "(" ~ #subexpr(0) ~ "," ~ #ident ~ "->" ~ #subexpr(0) ~ ")" + }, + |(name, _, arg, _, param, _, expr, _)| ExprElement::FunctionCall { + distinct: false, + name, + args: vec![arg], + params: vec![], + window: None, + lambda: Some(Lambda { + params: vec![param], + expr: Box::new(expr), + }), }, ); @@ -845,6 +867,7 @@ pub fn expr_element(i: Input) -> IResult> { args: opt_args.unwrap_or_default(), params: vec![], window: Some(window.1), + lambda: None, }, ); @@ -860,6 +883,7 @@ pub fn expr_element(i: Input) -> IResult> { args: opt_args.unwrap_or_default(), params: params.map(|x| x.1).unwrap_or_default(), window: None, + lambda: None, }, ); @@ -1028,6 +1052,7 @@ pub fn expr_element(i: Input) -> IResult> { | #trim_from : "`TRIM([(BOTH | LEADEING | TRAILING) ... FROM ...)`" | #is_distinct_from: "`... IS [NOT] DISTINCT FROM ...`" | #count_all_with_window : "`COUNT(*) OVER ...`" + | #function_call_with_lambda : "" | #function_call_with_window : "" | #function_call_with_params : "" | #function_call : "" diff --git a/src/query/ast/src/visitors/visitor.rs b/src/query/ast/src/visitors/visitor.rs index 73af30318aa49..3068a1efb5a40 100644 --- a/src/query/ast/src/visitors/visitor.rs +++ b/src/query/ast/src/visitors/visitor.rs @@ -222,6 +222,7 @@ pub trait Visitor<'ast>: Sized { } } + #[allow(clippy::too_many_arguments)] fn visit_function_call( &mut self, _span: Span, @@ -230,6 +231,7 @@ pub trait Visitor<'ast>: Sized { args: &'ast [Expr], _params: &'ast [Literal], over: &'ast Option, + lambda: &'ast Option, ) { for arg in args { walk_expr(self, arg); @@ -238,6 +240,9 @@ pub trait Visitor<'ast>: Sized { if let Some(over) = over { self.visit_window(over); } + if let Some(lambda) = lambda { + walk_expr(self, &lambda.expr) + } } fn visit_window(&mut self, window: &'ast Window) { diff --git a/src/query/ast/src/visitors/visitor_mut.rs b/src/query/ast/src/visitors/visitor_mut.rs index 12be6d4cdea86..6532ad4494b14 100644 --- a/src/query/ast/src/visitors/visitor_mut.rs +++ b/src/query/ast/src/visitors/visitor_mut.rs @@ -236,6 +236,7 @@ pub trait VisitorMut: Sized { } } + #[allow(clippy::too_many_arguments)] fn visit_function_call( &mut self, _span: Span, @@ -244,6 +245,7 @@ pub trait VisitorMut: Sized { args: &mut [Expr], _params: &mut [Literal], over: &mut Option, + lambda: &mut Option, ) { for arg in args.iter_mut() { walk_expr_mut(self, arg); @@ -269,6 +271,9 @@ pub trait VisitorMut: Sized { } } } + if let Some(lambda) = lambda { + walk_expr_mut(self, &mut lambda.expr) + } } fn visit_frame_bound(&mut self, bound: &mut WindowFrameBound) { diff --git a/src/query/ast/src/visitors/walk.rs b/src/query/ast/src/visitors/walk.rs index 148f88ea0c5b5..8160f3cfdcd6d 100644 --- a/src/query/ast/src/visitors/walk.rs +++ b/src/query/ast/src/visitors/walk.rs @@ -94,7 +94,8 @@ pub fn walk_expr<'a, V: Visitor<'a>>(visitor: &mut V, expr: &'a Expr) { args, params, window, - } => visitor.visit_function_call(*span, *distinct, name, args, params, window), + lambda, + } => visitor.visit_function_call(*span, *distinct, name, args, params, window, lambda), Expr::Case { span, operand, diff --git a/src/query/ast/src/visitors/walk_mut.rs b/src/query/ast/src/visitors/walk_mut.rs index 6aadf5216f1fa..83a9ae75f6e32 100644 --- a/src/query/ast/src/visitors/walk_mut.rs +++ b/src/query/ast/src/visitors/walk_mut.rs @@ -94,7 +94,8 @@ pub fn walk_expr_mut(visitor: &mut V, expr: &mut Expr) { args, params, window, - } => visitor.visit_function_call(*span, *distinct, name, args, params, window), + lambda, + } => visitor.visit_function_call(*span, *distinct, name, args, params, window, lambda), Expr::Case { span, operand, diff --git a/src/query/ast/tests/it/parser.rs b/src/query/ast/tests/it/parser.rs index ec174d31146fe..6682c3cb80ee3 100644 --- a/src/query/ast/tests/it/parser.rs +++ b/src/query/ast/tests/it/parser.rs @@ -662,6 +662,8 @@ fn test_expr() { r#"COUNT() OVER (ORDER BY hire_date ROWS UNBOUNDED PRECEDING)"#, r#"COUNT() OVER (ORDER BY hire_date ROWS CURRENT ROW)"#, r#"COUNT() OVER (ORDER BY hire_date ROWS 3 PRECEDING)"#, + r#"ARRAY_APPLY([1,2,3], x -> x + 1)"#, + r#"ARRAY_FILTER(col, y -> y % 2 = 0)"#, ]; for case in cases { diff --git a/src/query/ast/tests/it/testdata/expr.txt b/src/query/ast/tests/it/testdata/expr.txt index b110608f945c5..e276611aba6ea 100644 --- a/src/query/ast/tests/it/testdata/expr.txt +++ b/src/query/ast/tests/it/testdata/expr.txt @@ -96,6 +96,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -760,6 +761,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -1303,6 +1305,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -1328,6 +1331,7 @@ FunctionCall { args: [], params: [], window: None, + lambda: None, } @@ -1351,6 +1355,7 @@ FunctionCall { args: [], params: [], window: None, + lambda: None, } @@ -1407,6 +1412,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -2075,6 +2081,7 @@ BinaryOp { ], params: [], window: None, + lambda: None, }, }, not: true, @@ -2182,6 +2189,7 @@ BinaryOp { ], params: [], window: None, + lambda: None, }, right: Case { span: Some( @@ -2226,6 +2234,7 @@ BinaryOp { ], params: [], window: None, + lambda: None, }, right: Literal { span: Some( @@ -2280,6 +2289,7 @@ BinaryOp { ], params: [], window: None, + lambda: None, }, ), }, @@ -2712,6 +2722,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -2768,6 +2779,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -2816,6 +2828,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -2888,6 +2901,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -2928,6 +2942,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -2984,6 +2999,7 @@ FunctionCall { ], params: [], window: None, + lambda: None, } @@ -3190,6 +3206,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3239,6 +3256,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3305,6 +3323,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3402,6 +3421,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3506,6 +3526,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3582,6 +3603,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3643,6 +3665,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3702,6 +3725,7 @@ FunctionCall { }, ), ), + lambda: None, } @@ -3772,6 +3796,202 @@ FunctionCall { }, ), ), + lambda: None, +} + + +---------- Input ---------- +ARRAY_APPLY([1,2,3], x -> x + 1) +---------- Output --------- +ARRAY_APPLY([1, 2, 3], x -> (x + 1)) +---------- AST ------------ +FunctionCall { + span: Some( + 0..32, + ), + distinct: false, + name: Identifier { + name: "ARRAY_APPLY", + quote: None, + span: Some( + 0..11, + ), + }, + args: [ + Array { + span: Some( + 12..19, + ), + exprs: [ + Literal { + span: Some( + 13..14, + ), + lit: UInt64( + 1, + ), + }, + Literal { + span: Some( + 15..16, + ), + lit: UInt64( + 2, + ), + }, + Literal { + span: Some( + 17..18, + ), + lit: UInt64( + 3, + ), + }, + ], + }, + ], + params: [], + window: None, + lambda: Some( + Lambda { + params: [ + Identifier { + name: "x", + quote: None, + span: Some( + 21..22, + ), + }, + ], + expr: BinaryOp { + span: Some( + 28..29, + ), + op: Plus, + left: ColumnRef { + span: Some( + 26..27, + ), + database: None, + table: None, + column: Name( + Identifier { + name: "x", + quote: None, + span: Some( + 26..27, + ), + }, + ), + }, + right: Literal { + span: Some( + 30..31, + ), + lit: UInt64( + 1, + ), + }, + }, + }, + ), +} + + +---------- Input ---------- +ARRAY_FILTER(col, y -> y % 2 = 0) +---------- Output --------- +ARRAY_FILTER(col, y -> ((y % 2) = 0)) +---------- AST ------------ +FunctionCall { + span: Some( + 0..33, + ), + distinct: false, + name: Identifier { + name: "ARRAY_FILTER", + quote: None, + span: Some( + 0..12, + ), + }, + args: [ + ColumnRef { + span: Some( + 13..16, + ), + database: None, + table: None, + column: Name( + Identifier { + name: "col", + quote: None, + span: Some( + 13..16, + ), + }, + ), + }, + ], + params: [], + window: None, + lambda: Some( + Lambda { + params: [ + Identifier { + name: "y", + quote: None, + span: Some( + 18..19, + ), + }, + ], + expr: BinaryOp { + span: Some( + 29..30, + ), + op: Eq, + left: BinaryOp { + span: Some( + 25..26, + ), + op: Modulo, + left: ColumnRef { + span: Some( + 23..24, + ), + database: None, + table: None, + column: Name( + Identifier { + name: "y", + quote: None, + span: Some( + 23..24, + ), + }, + ), + }, + right: Literal { + span: Some( + 27..28, + ), + lit: UInt64( + 2, + ), + }, + }, + right: Literal { + span: Some( + 31..32, + ), + lit: UInt64( + 0, + ), + }, + }, + }, + ), } diff --git a/src/query/ast/tests/it/testdata/query.txt b/src/query/ast/tests/it/testdata/query.txt index 282e46084e9d4..3f7496730ab16 100644 --- a/src/query/ast/tests/it/testdata/query.txt +++ b/src/query/ast/tests/it/testdata/query.txt @@ -2473,6 +2473,7 @@ Query { ], params: [], window: None, + lambda: None, }, alias: Some( Identifier { @@ -2602,6 +2603,7 @@ Query { ], params: [], window: None, + lambda: None, }, alias: None, }, @@ -4484,6 +4486,7 @@ Query { ], params: [], window: None, + lambda: None, }, value_column: Identifier { name: "month", @@ -4834,6 +4837,7 @@ Query { }, ), ), + lambda: None, }, alias: None, }, @@ -5009,6 +5013,7 @@ Query { }, ), ), + lambda: None, }, alias: None, }, @@ -5057,6 +5062,7 @@ Query { }, ), ), + lambda: None, }, alias: None, }, @@ -5105,6 +5111,7 @@ Query { }, ), ), + lambda: None, }, alias: None, }, diff --git a/src/query/ast/tests/it/testdata/statement.txt b/src/query/ast/tests/it/testdata/statement.txt index 6cdfdeb9af4cd..27bedcef0c4cd 100644 --- a/src/query/ast/tests/it/testdata/statement.txt +++ b/src/query/ast/tests/it/testdata/statement.txt @@ -1005,6 +1005,7 @@ CreateTable( ], params: [], window: None, + lambda: None, }, ), ), @@ -6766,6 +6767,7 @@ Query( ], params: [], window: None, + lambda: None, }, accessor: Period { key: Identifier { @@ -11575,6 +11577,7 @@ CreateDatamaskPolicy( args: [], params: [], window: None, + lambda: None, }, list: [ Literal { diff --git a/src/query/functions/src/lib.rs b/src/query/functions/src/lib.rs index fa90e063a6b26..e569db1c2681d 100644 --- a/src/query/functions/src/lib.rs +++ b/src/query/functions/src/lib.rs @@ -30,6 +30,7 @@ pub fn is_builtin_function(name: &str) -> bool { BUILTIN_FUNCTIONS.contains(name) || AggregateFunctionFactory::instance().contains(name) || GENERAL_WINDOW_FUNCTIONS.contains(&name) + || GENERAL_LAMBDA_FUNCTIONS.contains(&name) } #[ctor] @@ -51,6 +52,8 @@ pub const GENERAL_WINDOW_FUNCTIONS: [&str; 13] = [ "cume_dist", ]; +pub const GENERAL_LAMBDA_FUNCTIONS: [&str; 3] = ["array_transform", "array_apply", "array_filter"]; + fn builtin_functions() -> FunctionRegistry { let mut registry = FunctionRegistry::empty(); diff --git a/src/query/profile/src/prof.rs b/src/query/profile/src/prof.rs index 9f423e46ef019..bad903b9bda06 100644 --- a/src/query/profile/src/prof.rs +++ b/src/query/profile/src/prof.rs @@ -62,6 +62,7 @@ pub enum OperatorType { Filter, ProjectSet, EvalScalar, + Lambda, Limit, TableScan, CteScan, @@ -84,6 +85,7 @@ impl Display for OperatorType { OperatorType::Filter => write!(f, "Filter"), OperatorType::ProjectSet => write!(f, "ProjectSet"), OperatorType::EvalScalar => write!(f, "EvalScalar"), + OperatorType::Lambda => write!(f, "Lambda"), OperatorType::Limit => write!(f, "Limit"), OperatorType::TableScan => write!(f, "TableScan"), OperatorType::Sort => write!(f, "Sort"), @@ -134,6 +136,7 @@ pub enum OperatorAttribute { Filter(FilterAttribute), EvalScalar(EvalScalarAttribute), ProjectSet(ProjectSetAttribute), + Lambda(LambdaAttribute), Limit(LimitAttribute), TableScan(TableScanAttribute), Sort(SortAttribute), @@ -172,6 +175,11 @@ pub struct ProjectSetAttribute { pub functions: String, } +#[derive(Debug, Clone)] +pub struct LambdaAttribute { + pub scalars: String, +} + #[derive(Debug, Clone)] pub struct FilterAttribute { pub predicate: String, diff --git a/src/query/service/src/pipelines/pipeline_builder.rs b/src/query/service/src/pipelines/pipeline_builder.rs index 89911f51e43b5..1f487a1b7f9ca 100644 --- a/src/query/service/src/pipelines/pipeline_builder.rs +++ b/src/query/service/src/pipelines/pipeline_builder.rs @@ -67,6 +67,7 @@ use common_sql::executor::ExchangeSink; use common_sql::executor::ExchangeSource; use common_sql::executor::Filter; use common_sql::executor::HashJoin; +use common_sql::executor::Lambda; use common_sql::executor::Limit; use common_sql::executor::MaterializedCte; use common_sql::executor::PhysicalPlan; @@ -208,6 +209,7 @@ impl PipelineBuilder { self.build_distributed_insert_select(insert_select) } PhysicalPlan::ProjectSet(project_set) => self.build_project_set(project_set), + PhysicalPlan::Lambda(lambda) => self.build_lambda(lambda), PhysicalPlan::Exchange(_) => Err(ErrorCode::Internal( "Invalid physical plan with PhysicalPlan::Exchange", )), @@ -750,6 +752,39 @@ impl PipelineBuilder { }) } + fn build_lambda(&mut self, lambda: &Lambda) -> Result<()> { + self.build_pipeline(&lambda.input)?; + + let funcs = lambda.lambda_funcs.clone(); + let op = BlockOperator::LambdaMap { funcs }; + + let input_schema = lambda.input.output_schema()?; + let func_ctx = self.ctx.get_function_context()?; + + let num_input_columns = input_schema.num_fields(); + + self.main_pipeline.add_transform(|input, output| { + let transform = + CompoundBlockOperator::new(vec![op.clone()], func_ctx.clone(), num_input_columns); + + if self.enable_profiling { + Ok(ProcessorPtr::create(TransformProfileWrapper::create( + transform, + input, + output, + lambda.plan_id, + self.proc_profs.clone(), + ))) + } else { + Ok(ProcessorPtr::create(Transformer::create( + input, output, transform, + ))) + } + })?; + + Ok(()) + } + fn build_aggregate_expand(&mut self, expand: &AggregateExpand) -> Result<()> { self.build_pipeline(&expand.input)?; let input_schema = expand.input.output_schema()?; diff --git a/src/query/sql/src/evaluator/block_operator.rs b/src/query/sql/src/evaluator/block_operator.rs index f2e418f85444e..6ee512262e50c 100644 --- a/src/query/sql/src/evaluator/block_operator.rs +++ b/src/query/sql/src/evaluator/block_operator.rs @@ -17,6 +17,8 @@ use std::sync::Arc; use common_catalog::plan::AggIndexMeta; use common_exception::Result; +use common_expression::types::array::ArrayColumn; +use common_expression::types::nullable::NullableColumn; use common_expression::types::nullable::NullableColumnBuilder; use common_expression::types::BooleanType; use common_expression::types::DataType; @@ -30,6 +32,7 @@ use common_expression::Evaluator; use common_expression::Expr; use common_expression::FieldIndex; use common_expression::FunctionContext; +use common_expression::Scalar; use common_expression::ScalarRef; use common_expression::Value; use common_functions::BUILTIN_FUNCTIONS; @@ -39,6 +42,7 @@ use common_pipeline_core::processors::Processor; use common_pipeline_transforms::processors::transforms::Transform; use common_pipeline_transforms::processors::transforms::Transformer; +use crate::executor::LambdaFunctionDesc; use crate::IndexType; /// `BlockOperator` takes a `DataBlock` as input and produces a `DataBlock` as output. @@ -64,6 +68,9 @@ pub enum BlockOperator { srf_exprs: Vec, unused_indices: HashSet, }, + + /// Execute lambda function on input [`DataBlock`]. + LambdaMap { funcs: Vec }, } impl BlockOperator { @@ -317,6 +324,133 @@ impl BlockOperator { } Ok(result) } + + BlockOperator::LambdaMap { funcs } => { + for func in funcs { + let expr = func.lambda_expr.as_expr(&BUILTIN_FUNCTIONS); + // TODO: Support multi args + let input_column = input.get_by_offset(func.arg_indices[0]); + match &input_column.value { + Value::Scalar(s) => match s { + Scalar::Null => { + let col = BlockEntry::new( + expr.data_type().clone(), + input_column.value.clone(), + ); + input.add_column(col); + } + Scalar::Array(c) => { + let entry = + BlockEntry::new(c.data_type(), Value::Column(c.clone())); + let block = DataBlock::new(vec![entry], c.len()); + + let evaluator = + Evaluator::new(&block, func_ctx, &BUILTIN_FUNCTIONS); + let result = evaluator.run(&expr)?; + let result_col = + result.convert_to_full_column(expr.data_type(), c.len()); + + let col = if func.func_name == "array_filter" { + let result_col = result_col.remove_nullable(); + let bitmap = result_col.as_boolean().unwrap(); + let filtered_inner_col = c.filter(bitmap); + BlockEntry::new( + input_column.data_type.clone(), + Value::Scalar(Scalar::Array(filtered_inner_col)), + ) + } else { + BlockEntry::new( + DataType::Array(Box::new(expr.data_type().clone())), + Value::Scalar(Scalar::Array(result_col)), + ) + }; + input.add_column(col); + } + _ => unreachable!(), + }, + Value::Column(c) => { + let (inner_col, inner_ty, offsets, validity) = match c { + Column::Array(box array_col) => ( + array_col.values.clone(), + array_col.values.data_type(), + array_col.offsets.clone(), + None, + ), + Column::Nullable(box nullable_col) => match &nullable_col.column { + Column::Array(box array_col) => ( + array_col.values.clone(), + array_col.values.data_type(), + array_col.offsets.clone(), + Some(nullable_col.validity.clone()), + ), + _ => unreachable!(), + }, + _ => unreachable!(), + }; + + let entry = BlockEntry::new(inner_ty, Value::Column(inner_col.clone())); + let block = DataBlock::new(vec![entry], inner_col.len()); + + let evaluator = Evaluator::new(&block, func_ctx, &BUILTIN_FUNCTIONS); + let result = evaluator.run(&expr)?; + let result_col = + result.convert_to_full_column(expr.data_type(), inner_col.len()); + + let col = if func.func_name == "array_filter" { + let result_col = result_col.remove_nullable(); + let bitmap = result_col.as_boolean().unwrap(); + let filtered_inner_col = inner_col.filter(bitmap); + // generate new offsets after filter. + let mut new_offset = 0; + let mut filtered_offsets = Vec::with_capacity(offsets.len()); + filtered_offsets.push(0); + for offset in offsets.windows(2) { + let off = offset[0] as usize; + let len = (offset[1] - offset[0]) as usize; + let unset_count = bitmap.null_count_range(off, len); + new_offset += (len - unset_count) as u64; + filtered_offsets.push(new_offset); + } + + let array_col = Column::Array(Box::new(ArrayColumn { + values: filtered_inner_col, + offsets: filtered_offsets.into(), + })); + let col = match validity { + Some(validity) => { + Value::Column(Column::Nullable(Box::new(NullableColumn { + column: array_col, + validity, + }))) + } + None => Value::Column(array_col), + }; + BlockEntry::new(input_column.data_type.clone(), col) + } else { + let array_col = Column::Array(Box::new(ArrayColumn { + values: result_col, + offsets, + })); + let array_ty = DataType::Array(Box::new(expr.data_type().clone())); + let (ty, col) = match validity { + Some(validity) => ( + DataType::Nullable(Box::new(array_ty)), + Value::Column(Column::Nullable(Box::new(NullableColumn { + column: array_col, + validity, + }))), + ), + None => (array_ty, Value::Column(array_col)), + }; + BlockEntry::new(ty, col) + }; + input.add_column(col); + } + } + } + + Ok(input) + } } } } @@ -395,6 +529,7 @@ impl Transform for CompoundBlockOperator { BlockOperator::Filter { .. } => "Filter", BlockOperator::Project { .. } => "Project", BlockOperator::FlatMap { .. } => "FlatMap", + BlockOperator::LambdaMap { .. } => "LambdaMap", } .to_string() }) diff --git a/src/query/sql/src/executor/format.rs b/src/query/sql/src/executor/format.rs index c0750092362e3..0e3133d59d7bd 100644 --- a/src/query/sql/src/executor/format.rs +++ b/src/query/sql/src/executor/format.rs @@ -31,6 +31,7 @@ use super::EvalScalar; use super::Exchange; use super::Filter; use super::HashJoin; +use super::Lambda; use super::Limit; use super::PhysicalPlan; use super::Project; @@ -183,6 +184,7 @@ fn to_format_tree( delete_final_to_format_tree(plan.as_ref(), metadata, profs) } PhysicalPlan::ProjectSet(plan) => project_set_to_format_tree(plan, metadata, profs), + PhysicalPlan::Lambda(plan) => lambda_to_format_tree(plan, metadata, profs), PhysicalPlan::RuntimeFilterSource(plan) => { runtime_filter_source_to_format_tree(plan, metadata, profs) } @@ -1009,6 +1011,45 @@ fn project_set_to_format_tree( )) } +fn lambda_to_format_tree( + plan: &Lambda, + metadata: &MetadataRef, + prof_span_set: &SharedProcessorProfiles, +) -> Result> { + let mut children = vec![]; + + if let Some(info) = &plan.stat_info { + let items = plan_stats_info_to_format_tree(info); + children.extend(items); + } + + append_profile_info(&mut children, prof_span_set, plan.plan_id); + + children.extend(vec![FormatTreeNode::new(format!( + "lambda functions: {}", + plan.lambda_funcs + .iter() + .map(|func| { + let arg_exprs = func.arg_exprs.join(", "); + let params = func.params.join(", "); + let lambda_expr = func.lambda_expr.as_expr(&BUILTIN_FUNCTIONS).sql_display(); + format!( + "{}({}, {} -> {})", + func.func_name, arg_exprs, params, lambda_expr + ) + }) + .collect::>() + .join(", ") + ))]); + + children.extend(vec![to_format_tree(&plan.input, metadata, prof_span_set)?]); + + Ok(FormatTreeNode::with_children( + "Lambda".to_string(), + children, + )) +} + fn runtime_filter_source_to_format_tree( plan: &RuntimeFilterSource, metadata: &MetadataRef, diff --git a/src/query/sql/src/executor/physical_plan.rs b/src/query/sql/src/executor/physical_plan.rs index 7640fa169f2b2..34b6f8d19c775 100644 --- a/src/query/sql/src/executor/physical_plan.rs +++ b/src/query/sql/src/executor/physical_plan.rs @@ -432,6 +432,43 @@ impl Window { } } +#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct LambdaFunctionDesc { + pub func_name: String, + pub output_column: IndexType, + pub arg_indices: Vec, + pub arg_exprs: Vec, + pub params: Vec, + pub lambda_expr: RemoteExpr, + pub data_type: Box, +} + +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct Lambda { + /// A unique id of operator in a `PhysicalPlan` tree. + /// Only used for display. + pub plan_id: u32, + + pub input: Box, + pub lambda_funcs: Vec, + + /// Only used for explain + pub stat_info: Option, +} + +impl Lambda { + pub fn output_schema(&self) -> Result { + let input_schema = self.input.output_schema()?; + let mut fields = input_schema.fields().clone(); + for lambda_func in self.lambda_funcs.iter() { + let name = lambda_func.output_column.to_string(); + let data_type = lambda_func.data_type.clone(); + fields.push(DataField::new(&name, *data_type)); + } + Ok(DataSchemaRefExt::create(fields)) + } +} + #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct Sort { /// A unique id of operator in a `PhysicalPlan` tree. @@ -915,6 +952,7 @@ pub enum PhysicalPlan { AggregatePartial(AggregatePartial), AggregateFinal(AggregateFinal), Window(Window), + Lambda(Lambda), Sort(Sort), Limit(Limit), RowFetch(RowFetch), @@ -962,6 +1000,7 @@ impl PhysicalPlan { PhysicalPlan::AggregatePartial(v) => v.plan_id, PhysicalPlan::AggregateFinal(v) => v.plan_id, PhysicalPlan::Window(v) => v.plan_id, + PhysicalPlan::Lambda(v) => v.plan_id, PhysicalPlan::Sort(v) => v.plan_id, PhysicalPlan::Limit(v) => v.plan_id, PhysicalPlan::RowFetch(v) => v.plan_id, @@ -992,6 +1031,7 @@ impl PhysicalPlan { PhysicalPlan::AggregatePartial(plan) => plan.output_schema(), PhysicalPlan::AggregateFinal(plan) => plan.output_schema(), PhysicalPlan::Window(plan) => plan.output_schema(), + PhysicalPlan::Lambda(plan) => plan.output_schema(), PhysicalPlan::Sort(plan) => plan.output_schema(), PhysicalPlan::Limit(plan) => plan.output_schema(), PhysicalPlan::RowFetch(plan) => plan.output_schema(), @@ -1023,6 +1063,7 @@ impl PhysicalPlan { PhysicalPlan::AggregatePartial(_) => "AggregatePartial".to_string(), PhysicalPlan::AggregateFinal(_) => "AggregateFinal".to_string(), PhysicalPlan::Window(_) => "Window".to_string(), + PhysicalPlan::Lambda(_) => "Lambda".to_string(), PhysicalPlan::Sort(_) => "Sort".to_string(), PhysicalPlan::Limit(_) => "Limit".to_string(), PhysicalPlan::RowFetch(_) => "RowFetch".to_string(), @@ -1056,6 +1097,7 @@ impl PhysicalPlan { PhysicalPlan::AggregatePartial(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::AggregateFinal(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::Window(plan) => Box::new(std::iter::once(plan.input.as_ref())), + PhysicalPlan::Lambda(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::Sort(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::Limit(plan) => Box::new(std::iter::once(plan.input.as_ref())), PhysicalPlan::RowFetch(plan) => Box::new(std::iter::once(plan.input.as_ref())), @@ -1097,6 +1139,7 @@ impl PhysicalPlan { PhysicalPlan::Project(plan) => plan.input.try_find_single_data_source(), PhysicalPlan::EvalScalar(plan) => plan.input.try_find_single_data_source(), PhysicalPlan::Window(plan) => plan.input.try_find_single_data_source(), + PhysicalPlan::Lambda(plan) => plan.input.try_find_single_data_source(), PhysicalPlan::Sort(plan) => plan.input.try_find_single_data_source(), PhysicalPlan::Limit(plan) => plan.input.try_find_single_data_source(), PhysicalPlan::Exchange(plan) => plan.input.try_find_single_data_source(), diff --git a/src/query/sql/src/executor/physical_plan_builder.rs b/src/query/sql/src/executor/physical_plan_builder.rs index 76e008173ae31..411076a1ddff4 100644 --- a/src/query/sql/src/executor/physical_plan_builder.rs +++ b/src/query/sql/src/executor/physical_plan_builder.rs @@ -14,6 +14,7 @@ use std::collections::btree_map::Entry; use std::collections::BTreeMap; +use std::collections::HashMap; use std::collections::HashSet; use std::sync::Arc; @@ -53,6 +54,7 @@ use super::AggregatePartial; use super::EvalScalar; use super::Exchange as PhysicalExchange; use super::Filter; +use super::Lambda; use super::Limit; use super::NthValueFunctionDesc; use super::ProjectSet; @@ -70,6 +72,7 @@ use crate::executor::CteScan; use crate::executor::FragmentKind; use crate::executor::LagLeadDefault; use crate::executor::LagLeadFunctionDesc; +use crate::executor::LambdaFunctionDesc; use crate::executor::MaterializedCte; use crate::executor::NtileFunctionDesc; use crate::executor::PhysicalJoinType; @@ -1115,6 +1118,111 @@ impl PhysicalPlanBuilder { })) } + RelOperator::Lambda(lambda) => { + let input = self.build(s_expr.child(0)?).await?; + let input_schema = input.output_schema()?; + let mut index = input_schema.num_fields(); + let mut lambda_index_map = HashMap::new(); + let lambda_funcs = lambda + .items + .iter() + .map(|item| { + if let ScalarExpr::LambdaFunction(func) = &item.scalar { + let arg_indices = func + .args + .iter() + .map(|arg| { + match arg { + ScalarExpr::BoundColumnRef(col) => { + let index = input_schema + .index_of(&col.column.index.to_string()) + .unwrap(); + Ok(index) + } + ScalarExpr::LambdaFunction(inner_func) => { + // nested lambda function as an argument of parent lambda function + let index = lambda_index_map.get(&inner_func.display_name).unwrap(); + Ok(*index) + } + _ => { + Err(ErrorCode::Internal( + "lambda function's argument must be a BoundColumnRef or LambdaFunction" + .to_string(), + )) + } + } + }) + .collect::>>()?; + + lambda_index_map.insert( + func.display_name.clone(), + index, + ); + index += 1; + + let arg_exprs = func + .args + .iter() + .map(|arg| { + let expr = arg.as_expr()?; + let remote_expr = expr.as_remote_expr(); + Ok(remote_expr.as_expr(&BUILTIN_FUNCTIONS).sql_display()) + }) + .collect::>>()?; + + let params = func + .params + .iter() + .map(|(param, _)| param.clone()) + .collect::>(); + + // build schema for lambda expr. + let mut field_index = 0; + let lambda_fields = func + .params + .iter() + .map(|(_, ty)| { + let field = DataField::new(&field_index.to_string(), ty.clone()); + field_index += 1; + field + }) + .collect::>(); + let lambda_schema = DataSchema::new(lambda_fields); + + let expr = func + .lambda_expr + .resolve_and_check(&lambda_schema)? + .project_column_ref(|index| { + lambda_schema.index_of(&index.to_string()).unwrap() + }); + let (expr, _) = + ConstantFolder::fold(&expr, &self.func_ctx, &BUILTIN_FUNCTIONS); + let lambda_expr = expr.as_remote_expr(); + + let lambda_func = LambdaFunctionDesc { + func_name: func.func_name.clone(), + output_column: item.index, + arg_indices, + arg_exprs, + params, + lambda_expr, + data_type: func.return_type.clone(), + }; + Ok(lambda_func) + } else { + Err(ErrorCode::Internal("Expected lambda function".to_string())) + } + }) + .collect::>>()?; + + Ok(PhysicalPlan::Lambda(Lambda { + plan_id: self.next_plan_id(), + input: Box::new(input), + lambda_funcs, + stat_info: Some(stat_info), + })) + } + _ => Err(ErrorCode::Internal(format!( "Unsupported physical plan: {:?}", s_expr.plan() diff --git a/src/query/sql/src/executor/physical_plan_display.rs b/src/query/sql/src/executor/physical_plan_display.rs index 81e1f0b7f7c7c..4470eaa4369be 100644 --- a/src/query/sql/src/executor/physical_plan_display.rs +++ b/src/query/sql/src/executor/physical_plan_display.rs @@ -35,6 +35,7 @@ use crate::executor::ExchangeSink; use crate::executor::ExchangeSource; use crate::executor::Filter; use crate::executor::HashJoin; +use crate::executor::Lambda; use crate::executor::Limit; use crate::executor::MaterializedCte; use crate::executor::PhysicalPlan; @@ -83,6 +84,7 @@ impl<'a> Display for PhysicalPlanIndentFormatDisplay<'a> { PhysicalPlan::DeletePartial(delete) => write!(f, "{}", delete)?, PhysicalPlan::DeleteFinal(delete) => write!(f, "{}", delete)?, PhysicalPlan::ProjectSet(unnest) => write!(f, "{}", unnest)?, + PhysicalPlan::Lambda(lambda) => write!(f, "{}", lambda)?, PhysicalPlan::RuntimeFilterSource(plan) => write!(f, "{}", plan)?, PhysicalPlan::RangeJoin(plan) => write!(f, "{}", plan)?, PhysicalPlan::DistributedCopyIntoTableFromStage(copy_into_table_from_stage) => { @@ -418,3 +420,22 @@ impl Display for ProjectSet { ) } } + +impl Display for Lambda { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let scalars = self + .lambda_funcs + .iter() + .map(|func| { + let arg_exprs = func.arg_exprs.join(", "); + let params = func.params.join(", "); + let lambda_expr = func.lambda_expr.as_expr(&BUILTIN_FUNCTIONS).sql_display(); + format!( + "{}({}, {} -> {})", + func.func_name, arg_exprs, params, lambda_expr + ) + }) + .collect::>(); + write!(f, "Lambda functions : {}", scalars.join(", ")) + } +} diff --git a/src/query/sql/src/executor/physical_plan_visitor.rs b/src/query/sql/src/executor/physical_plan_visitor.rs index 7ce9833530fac..ce4772f34376c 100644 --- a/src/query/sql/src/executor/physical_plan_visitor.rs +++ b/src/query/sql/src/executor/physical_plan_visitor.rs @@ -28,6 +28,7 @@ use super::ExchangeSink; use super::ExchangeSource; use super::Filter; use super::HashJoin; +use super::Lambda; use super::Limit; use super::PhysicalPlan; use super::Project; @@ -64,6 +65,7 @@ pub trait PhysicalPlanReplacer { PhysicalPlan::UnionAll(plan) => self.replace_union(plan), PhysicalPlan::DistributedInsertSelect(plan) => self.replace_insert_select(plan), PhysicalPlan::ProjectSet(plan) => self.replace_project_set(plan), + PhysicalPlan::Lambda(plan) => self.replace_lambda(plan), PhysicalPlan::RuntimeFilterSource(plan) => self.replace_runtime_filter_source(plan), PhysicalPlan::DeletePartial(plan) => self.replace_delete_partial(plan), PhysicalPlan::DeleteFinal(plan) => self.replace_delete_final(plan), @@ -368,6 +370,16 @@ pub trait PhysicalPlanReplacer { })) } + fn replace_lambda(&mut self, plan: &Lambda) -> Result { + let input = self.replace(&plan.input)?; + Ok(PhysicalPlan::Lambda(Lambda { + plan_id: plan.plan_id, + input: Box::new(input), + lambda_funcs: plan.lambda_funcs.clone(), + stat_info: plan.stat_info.clone(), + })) + } + fn replace_runtime_filter_source( &mut self, plan: &RuntimeFilterSource, @@ -446,6 +458,9 @@ impl PhysicalPlan { PhysicalPlan::ProjectSet(plan) => { Self::traverse(&plan.input, pre_visit, visit, post_visit) } + PhysicalPlan::Lambda(plan) => { + Self::traverse(&plan.input, pre_visit, visit, post_visit) + } PhysicalPlan::DistributedCopyIntoTableFromStage(_) => {} PhysicalPlan::CopyIntoTableFromQuery(plan) => { Self::traverse(&plan.input, pre_visit, visit, post_visit); diff --git a/src/query/sql/src/executor/profile.rs b/src/query/sql/src/executor/profile.rs index 521defc2b5d7a..0ae5875634aa8 100644 --- a/src/query/sql/src/executor/profile.rs +++ b/src/query/sql/src/executor/profile.rs @@ -21,6 +21,7 @@ use common_profile::EvalScalarAttribute; use common_profile::ExchangeAttribute; use common_profile::FilterAttribute; use common_profile::JoinAttribute; +use common_profile::LambdaAttribute; use common_profile::LimitAttribute; use common_profile::OperatorAttribute; use common_profile::OperatorProfile; @@ -153,6 +154,33 @@ fn flatten_plan_node_profile( }; plan_node_profs.push(prof); } + PhysicalPlan::Lambda(lambda) => { + flatten_plan_node_profile(metadata, &lambda.input, profs, plan_node_profs)?; + let proc_prof = profs.get(&lambda.plan_id).copied().unwrap_or_default(); + let prof = OperatorProfile { + id: lambda.plan_id, + operator_type: OperatorType::Lambda, + execution_info: proc_prof.into(), + children: vec![lambda.input.get_id()], + attribute: OperatorAttribute::Lambda(LambdaAttribute { + scalars: lambda + .lambda_funcs + .iter() + .map(|func| { + let arg_exprs = func.arg_exprs.join(", "); + let params = func.params.join(", "); + let lambda_expr = + func.lambda_expr.as_expr(&BUILTIN_FUNCTIONS).sql_display(); + format!( + "{}({}, {} -> {})", + func.func_name, arg_exprs, params, lambda_expr + ) + }) + .join(", "), + }), + }; + plan_node_profs.push(prof); + } PhysicalPlan::AggregateExpand(expand) => { flatten_plan_node_profile(metadata, &expand.input, profs, plan_node_profs)?; let proc_prof = profs.get(&expand.plan_id).copied().unwrap_or_default(); diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index ff69c5de632e3..e7316bf366527 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -45,6 +45,7 @@ use crate::plans::CastExpr; use crate::plans::EvalScalar; use crate::plans::FunctionCall; use crate::plans::LagLeadFunction; +use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; @@ -210,6 +211,25 @@ impl<'a> AggregateRewriter<'a> { } .into()) } + + ScalarExpr::LambdaFunction(lambda_func) => { + let new_args = lambda_func + .args + .iter() + .map(|arg| self.visit(arg)) + .collect::>>()?; + + Ok(LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args: new_args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type.clone(), + } + .into()) + } } } diff --git a/src/query/sql/src/planner/binder/bind_context.rs b/src/query/sql/src/planner/binder/bind_context.rs index 644c87a2bf45e..80b60f9655724 100644 --- a/src/query/sql/src/planner/binder/bind_context.rs +++ b/src/query/sql/src/planner/binder/bind_context.rs @@ -36,6 +36,7 @@ use enum_as_inner::EnumAsInner; use super::AggregateInfo; use super::INTERNAL_COLUMN_FACTORY; use crate::binder::column_binding::ColumnBinding; +use crate::binder::lambda::LambdaInfo; use crate::binder::window::WindowInfo; use crate::binder::ColumnBindingBuilder; use crate::normalize_identifier; @@ -60,6 +61,7 @@ pub enum ExprContext { InSetReturningFunction, InAggregateFunction, + InLambdaFunction, #[default] Unknown, @@ -114,6 +116,8 @@ pub struct BindContext { pub windows: WindowInfo, + pub lambda_info: LambdaInfo, + /// True if there is aggregation in current context, which means /// non-grouping columns cannot be referenced outside aggregation /// functions, otherwise a grouping error will be raised. @@ -163,6 +167,7 @@ impl BindContext { bound_internal_columns: BTreeMap::new(), aggregate_info: AggregateInfo::default(), windows: WindowInfo::default(), + lambda_info: LambdaInfo::default(), in_grouping: false, ctes_map: Box::default(), materialized_ctes: HashSet::new(), @@ -181,6 +186,7 @@ impl BindContext { bound_internal_columns: BTreeMap::new(), aggregate_info: Default::default(), windows: Default::default(), + lambda_info: LambdaInfo::default(), in_grouping: false, ctes_map: parent.ctes_map.clone(), materialized_ctes: parent.materialized_ctes.clone(), diff --git a/src/query/sql/src/planner/binder/delete.rs b/src/query/sql/src/planner/binder/delete.rs index c8e3b9d491360..c3ba76058fe03 100644 --- a/src/query/sql/src/planner/binder/delete.rs +++ b/src/query/sql/src/planner/binder/delete.rs @@ -200,7 +200,8 @@ impl Binder { ScalarExpr::BoundColumnRef(_) | ScalarExpr::ConstantExpr(_) | ScalarExpr::WindowFunction(_) - | ScalarExpr::AggregateFunction(_) => {} + | ScalarExpr::AggregateFunction(_) + | ScalarExpr::LambdaFunction(_) => {} } Ok(()) } diff --git a/src/query/sql/src/planner/binder/lambda.rs b/src/query/sql/src/planner/binder/lambda.rs new file mode 100644 index 0000000000000..9461c270ce989 --- /dev/null +++ b/src/query/sql/src/planner/binder/lambda.rs @@ -0,0 +1,278 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use common_exception::Result; + +use super::select::SelectList; +use crate::optimizer::SExpr; +use crate::plans::AggregateFunction; +use crate::plans::BoundColumnRef; +use crate::plans::CastExpr; +use crate::plans::EvalScalar; +use crate::plans::FunctionCall; +use crate::plans::Lambda; +use crate::plans::LambdaFunc; +use crate::plans::ScalarExpr; +use crate::plans::ScalarItem; +use crate::plans::WindowFunc; +use crate::plans::WindowOrderBy; +use crate::BindContext; +use crate::Binder; +use crate::ColumnBinding; +use crate::MetadataRef; +use crate::Visibility; + +#[derive(Default, Clone, PartialEq, Eq, Debug)] +pub struct LambdaInfo { + /// Arguments of lambda functions + pub lambda_arguments: Vec, + /// Lambda functions + pub lambda_functions: Vec, + /// Mapping: (lambda function display name) -> (derived column ref) + /// This is used to generate column in projection. + pub lambda_functions_map: HashMap, +} + +pub(super) struct LambdaRewriter<'a> { + pub bind_context: &'a mut BindContext, + pub metadata: MetadataRef, +} + +impl<'a> LambdaRewriter<'a> { + pub fn new(bind_context: &'a mut BindContext, metadata: MetadataRef) -> Self { + Self { + bind_context, + metadata, + } + } + + pub fn visit(&mut self, scalar: &ScalarExpr) -> Result { + match scalar { + ScalarExpr::BoundColumnRef(_) => Ok(scalar.clone()), + ScalarExpr::ConstantExpr(_) => Ok(scalar.clone()), + ScalarExpr::FunctionCall(func) => { + let new_args = func + .arguments + .iter() + .map(|arg| self.visit(arg)) + .collect::>>()?; + Ok(FunctionCall { + span: func.span, + func_name: func.func_name.clone(), + params: func.params.clone(), + arguments: new_args, + } + .into()) + } + ScalarExpr::CastExpr(cast) => Ok(CastExpr { + span: cast.span, + is_try: cast.is_try, + argument: Box::new(self.visit(&cast.argument)?), + target_type: cast.target_type.clone(), + } + .into()), + + // TODO(leiysky): should we recursively process subquery here? + ScalarExpr::SubqueryExpr(_) => Ok(scalar.clone()), + + ScalarExpr::AggregateFunction(agg_func) => { + let new_args = agg_func + .args + .iter() + .map(|arg| self.visit(arg)) + .collect::>>()?; + Ok(AggregateFunction { + func_name: agg_func.func_name.clone(), + distinct: agg_func.distinct, + params: agg_func.params.clone(), + args: new_args, + return_type: agg_func.return_type.clone(), + display_name: agg_func.display_name.clone(), + } + .into()) + } + + ScalarExpr::WindowFunction(window) => { + let new_partition_by = window + .partition_by + .iter() + .map(|partition_by| self.visit(partition_by)) + .collect::>>()?; + + let mut new_order_by = Vec::with_capacity(window.order_by.len()); + for order_by in window.order_by.iter() { + new_order_by.push(WindowOrderBy { + expr: self.visit(&order_by.expr)?, + asc: order_by.asc, + nulls_first: order_by.nulls_first, + }); + } + + Ok(WindowFunc { + span: window.span, + display_name: window.display_name.clone(), + partition_by: new_partition_by, + func: window.func.clone(), + order_by: new_order_by, + frame: window.frame.clone(), + } + .into()) + } + + ScalarExpr::LambdaFunction(lambda_func) => { + let mut replaced_args = Vec::with_capacity(lambda_func.args.len()); + for (i, arg) in lambda_func.args.iter().enumerate() { + let new_arg = self.visit(arg)?; + if let ScalarExpr::LambdaFunction(_) = new_arg { + replaced_args.push(new_arg); + continue; + } + + let replaced_arg = if let ScalarExpr::BoundColumnRef(ref column_ref) = new_arg { + column_ref.clone() + } else { + let name = format!("{}_arg_{}", &lambda_func.display_name, i); + let index = self + .metadata + .write() + .add_derived_column(name.clone(), new_arg.data_type()?); + + // Generate a ColumnBinding for each argument of lambda function + let column = ColumnBinding { + database_name: None, + table_name: None, + column_position: None, + table_index: None, + column_name: name, + index, + data_type: Box::new(new_arg.data_type()?), + visibility: Visibility::Visible, + virtual_computed_expr: None, + }; + + BoundColumnRef { + span: new_arg.span(), + column, + } + }; + + self.bind_context + .lambda_info + .lambda_arguments + .push(ScalarItem { + index: replaced_arg.column.index, + scalar: new_arg, + }); + replaced_args.push(replaced_arg.into()); + } + + let index = self + .metadata + .write() + .add_derived_column(lambda_func.display_name.clone(), scalar.data_type()?); + + let column = ColumnBinding { + database_name: None, + table_name: None, + column_position: None, + table_index: None, + column_name: lambda_func.display_name.clone(), + index, + data_type: Box::new(scalar.data_type()?), + visibility: Visibility::Visible, + virtual_computed_expr: None, + }; + + let replaced_column = BoundColumnRef { + span: scalar.span(), + column, + }; + + let replaced_lambda = LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args: replaced_args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type.clone(), + }; + + self.bind_context + .lambda_info + .lambda_functions_map + .insert(lambda_func.display_name.clone(), replaced_column); + self.bind_context + .lambda_info + .lambda_functions + .push(ScalarItem { + index, + scalar: replaced_lambda.clone().into(), + }); + + Ok(replaced_lambda.into()) + } + } + } +} + +impl Binder { + /// Analyze lambda functions in select clause, this will rewrite lambda functions. + /// See [`LambdaRewriter`] for more details. + pub(crate) fn analyze_lambda( + &mut self, + bind_context: &mut BindContext, + select_list: &mut SelectList, + ) -> Result<()> { + for item in select_list.items.iter_mut() { + let mut rewriter = LambdaRewriter::new(bind_context, self.metadata.clone()); + let new_scalar = rewriter.visit(&item.scalar)?; + item.scalar = new_scalar; + } + + Ok(()) + } + + #[async_backtrace::framed] + pub async fn bind_lambda( + &mut self, + bind_context: &mut BindContext, + child: SExpr, + ) -> Result { + let lambda_info = &bind_context.lambda_info; + if lambda_info.lambda_functions.is_empty() { + return Ok(child); + } + + let mut new_expr = child; + if !lambda_info.lambda_arguments.is_empty() { + let mut scalar_items = lambda_info.lambda_arguments.clone(); + scalar_items.sort_by_key(|item| item.index); + let eval_scalar = EvalScalar { + items: scalar_items, + }; + new_expr = SExpr::create_unary(Arc::new(eval_scalar.into()), Arc::new(new_expr)); + } + + let lambda_plan = Lambda { + items: lambda_info.lambda_functions.clone(), + }; + new_expr = SExpr::create_unary(Arc::new(lambda_plan.into()), Arc::new(new_expr)); + + Ok(new_expr) + } +} diff --git a/src/query/sql/src/planner/binder/mod.rs b/src/query/sql/src/planner/binder/mod.rs index c93f008b06b3a..9564643c99672 100644 --- a/src/query/sql/src/planner/binder/mod.rs +++ b/src/query/sql/src/planner/binder/mod.rs @@ -28,6 +28,7 @@ mod insert; mod internal_column_factory; mod join; mod kill; +mod lambda; mod limit; mod location; mod presign; diff --git a/src/query/sql/src/planner/binder/project_set.rs b/src/query/sql/src/planner/binder/project_set.rs index dd5911f87f203..395f470488253 100644 --- a/src/query/sql/src/planner/binder/project_set.rs +++ b/src/query/sql/src/planner/binder/project_set.rs @@ -16,6 +16,7 @@ use std::sync::Arc; use common_ast::ast::Expr; use common_ast::ast::Identifier; +use common_ast::ast::Lambda; use common_ast::ast::Literal; use common_ast::ast::Window; use common_ast::Visitor; @@ -52,6 +53,7 @@ impl<'a> Visitor<'a> for SrfCollector { args: &'a [Expr], params: &'a [Literal], over: &'a Option, + lambda: &'a Option, ) { if BUILTIN_FUNCTIONS .get_property(&name.name) @@ -66,6 +68,7 @@ impl<'a> Visitor<'a> for SrfCollector { args: args.to_vec(), params: params.to_vec(), window: over.clone(), + lambda: lambda.clone(), }); } else { for arg in args.iter() { diff --git a/src/query/sql/src/planner/binder/scalar_common.rs b/src/query/sql/src/planner/binder/scalar_common.rs index 129dbcf16402b..488ad58637722 100644 --- a/src/query/sql/src/planner/binder/scalar_common.rs +++ b/src/query/sql/src/planner/binder/scalar_common.rs @@ -202,6 +202,10 @@ pub fn prune_by_children(scalar: &ScalarExpr, columns: &HashSet) -> .args .iter() .all(|arg| prune_by_children(arg, columns)), + ScalarExpr::LambdaFunction(scalar) => scalar + .args + .iter() + .all(|arg| prune_by_children(arg, columns)), ScalarExpr::FunctionCall(scalar) => scalar .arguments .iter() diff --git a/src/query/sql/src/planner/binder/scalar_visitor.rs b/src/query/sql/src/planner/binder/scalar_visitor.rs index b1fc549f012b4..87a94af703724 100644 --- a/src/query/sql/src/planner/binder/scalar_visitor.rs +++ b/src/query/sql/src/planner/binder/scalar_visitor.rs @@ -53,6 +53,11 @@ pub trait ScalarVisitor: Sized { stack.push(RecursionProcessing::Call(arg)); } } + ScalarExpr::LambdaFunction(func) => { + for arg in &func.args { + stack.push(RecursionProcessing::Call(arg)); + } + } ScalarExpr::WindowFunction(WindowFunc { func, partition_by, diff --git a/src/query/sql/src/planner/binder/select.rs b/src/query/sql/src/planner/binder/select.rs index 7f4d3e3d81161..05884ebe17032 100644 --- a/src/query/sql/src/planner/binder/select.rs +++ b/src/query/sql/src/planner/binder/select.rs @@ -165,6 +165,9 @@ impl Binder { .normalize_select_list(&mut from_context, &stmt.select_list) .await?; + // analyze lambda + self.analyze_lambda(&mut from_context, &mut select_list)?; + // This will potentially add some alias group items to `from_context` if find some. if let Some(group_by) = stmt.group_by.as_ref() { self.analyze_group_items(&mut from_context, &select_list, group_by) @@ -233,6 +236,10 @@ impl Binder { )?; } + if !from_context.lambda_info.lambda_functions.is_empty() { + s_expr = self.bind_lambda(&mut from_context, s_expr).await?; + } + if !from_context.aggregate_info.aggregate_functions.is_empty() || !from_context.aggregate_info.group_items.is_empty() { @@ -920,6 +927,7 @@ impl<'a> SelectRewriter<'a> { args, params: vec![], window: None, + lambda: None, }), alias, } diff --git a/src/query/sql/src/planner/binder/sort.rs b/src/query/sql/src/planner/binder/sort.rs index 40ddc2a86e15d..413124e281ede 100644 --- a/src/query/sql/src/planner/binder/sort.rs +++ b/src/query/sql/src/planner/binder/sort.rs @@ -35,6 +35,7 @@ use crate::plans::BoundColumnRef; use crate::plans::CastExpr; use crate::plans::EvalScalar; use crate::plans::FunctionCall; +use crate::plans::LambdaFunc; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::Sort; @@ -341,6 +342,24 @@ impl Binder { return_type: return_type.clone(), })) } + ScalarExpr::LambdaFunction(lambda_func) => { + let args = lambda_func + .args + .iter() + .map(|arg| { + self.rewrite_scalar_with_replacement(bind_context, arg, replacement_fn) + }) + .collect::>>()?; + Ok(ScalarExpr::LambdaFunction(LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type.clone(), + })) + } window @ ScalarExpr::WindowFunction(_) => { let mut rewriter = WindowRewriter::new(bind_context, self.metadata.clone()); rewriter.visit(window) diff --git a/src/query/sql/src/planner/binder/table.rs b/src/query/sql/src/planner/binder/table.rs index 52cdcf1a7b54b..52bffa4adbafb 100644 --- a/src/query/sql/src/planner/binder/table.rs +++ b/src/query/sql/src/planner/binder/table.rs @@ -489,6 +489,7 @@ impl Binder { params: vec![], args: params.clone(), window: None, + lambda: None, }), alias: None, }], @@ -772,6 +773,7 @@ impl Binder { columns: vec![], aggregate_info: Default::default(), windows: Default::default(), + lambda_info: Default::default(), in_grouping: false, ctes_map: Box::default(), materialized_ctes: HashSet::new(), diff --git a/src/query/sql/src/planner/binder/window.rs b/src/query/sql/src/planner/binder/window.rs index e07c04aab82d1..ebdefdafcccb8 100644 --- a/src/query/sql/src/planner/binder/window.rs +++ b/src/query/sql/src/planner/binder/window.rs @@ -29,6 +29,7 @@ use crate::plans::BoundColumnRef; use crate::plans::CastExpr; use crate::plans::FunctionCall; use crate::plans::LagLeadFunction; +use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; @@ -305,6 +306,24 @@ impl<'a> WindowRewriter<'a> { self.in_window = false; Ok(scalar) } + + ScalarExpr::LambdaFunction(lambda_func) => { + let new_args = lambda_func + .args + .iter() + .map(|arg| self.visit(arg)) + .collect::>>()?; + Ok(LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args: new_args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type.clone(), + } + .into()) + } } } diff --git a/src/query/sql/src/planner/expression_parser.rs b/src/query/sql/src/planner/expression_parser.rs index b8061b5df580a..d1cc963df7a81 100644 --- a/src/query/sql/src/planner/expression_parser.rs +++ b/src/query/sql/src/planner/expression_parser.rs @@ -26,6 +26,7 @@ use common_catalog::table::Table; use common_catalog::table_context::TableContext; use common_exception::ErrorCode; use common_exception::Result; +use common_expression::infer_schema_type; use common_expression::infer_table_schema; use common_expression::types::DataType; use common_expression::ConstantFolder; @@ -44,11 +45,13 @@ use common_settings::Settings; use parking_lot::RwLock; use crate::binder::ColumnBindingBuilder; +use crate::binder::ExprContext; use crate::planner::binder::BindContext; use crate::planner::semantic::NameResolutionContext; use crate::planner::semantic::TypeChecker; use crate::plans::CastExpr; use crate::BaseTableColumn; +use crate::ColumnBinding; use crate::ColumnEntry; use crate::IdentifierNormalizer; use crate::Metadata; @@ -337,6 +340,54 @@ pub fn parse_computed_expr_to_string( Ok(format!("{:#}", ast)) } +pub fn parse_lambda_expr( + ctx: Arc, + column_name: &str, + data_type: &DataType, + ast: &AExpr, +) -> Result> { + let settings = Settings::create("".to_string()); + let mut bind_context = BindContext::new(); + let mut metadata = Metadata::default(); + + bind_context.set_expr_context(ExprContext::InLambdaFunction); + bind_context.add_column_binding(ColumnBinding { + database_name: None, + table_name: None, + column_position: None, + table_index: None, + column_name: column_name.to_string(), + index: 0, + data_type: Box::new(data_type.clone()), + visibility: Visibility::Visible, + virtual_computed_expr: None, + }); + + let table_type = infer_schema_type(data_type)?; + metadata.add_base_table_column( + column_name.to_string(), + table_type, + 0, + None, + None, + None, + None, + ); + + let name_resolution_ctx = NameResolutionContext::try_from(settings.as_ref())?; + let mut type_checker = TypeChecker::new( + &mut bind_context, + ctx.clone(), + &name_resolution_ctx, + Arc::new(RwLock::new(metadata)), + &[], + false, + false, + ); + + block_in_place(|| Handle::current().block_on(type_checker.resolve(ast))) +} + #[derive(Default)] struct DummyTable { info: TableInfo, diff --git a/src/query/sql/src/planner/format/display_rel_operator.rs b/src/query/sql/src/planner/format/display_rel_operator.rs index 140a7b23199f0..ddbc15435d96b 100644 --- a/src/query/sql/src/planner/format/display_rel_operator.rs +++ b/src/query/sql/src/planner/format/display_rel_operator.rs @@ -78,6 +78,7 @@ impl Display for FormatContext { RelOperator::ProjectSet(_) => write!(f, "ProjectSet"), RelOperator::CteScan(_) => write!(f, "CteScan"), RelOperator::MaterializedCte(_) => write!(f, "MaterializedCte"), + RelOperator::Lambda(_) => write!(f, "Lambda"), }, Self::Text(text) => write!(f, "{}", text), } @@ -102,6 +103,7 @@ pub fn format_scalar(scalar: &ScalarExpr) -> String { ScalarExpr::ConstantExpr(constant) => constant.value.to_string(), ScalarExpr::WindowFunction(win) => win.display_name.clone(), ScalarExpr::AggregateFunction(agg) => agg.display_name.clone(), + ScalarExpr::LambdaFunction(lambda) => lambda.display_name.clone(), ScalarExpr::FunctionCall(func) => { format!( "{}({})", diff --git a/src/query/sql/src/planner/optimizer/cost/cost_model.rs b/src/query/sql/src/planner/optimizer/cost/cost_model.rs index c8fac21b97466..72f295b622b4c 100644 --- a/src/query/sql/src/planner/optimizer/cost/cost_model.rs +++ b/src/query/sql/src/planner/optimizer/cost/cost_model.rs @@ -51,6 +51,7 @@ fn compute_cost_impl(memo: &Memo, m_expr: &MExpr) -> Result { | RelOperator::Window(_) | RelOperator::Sort(_) | RelOperator::ProjectSet(_) + | RelOperator::Lambda(_) | RelOperator::Limit(_) => compute_cost_unary_common_operator(memo, m_expr), _ => Err(ErrorCode::Internal("Cannot compute cost from logical plan")), diff --git a/src/query/sql/src/planner/optimizer/format.rs b/src/query/sql/src/planner/optimizer/format.rs index ef00883f35103..82332fae31327 100644 --- a/src/query/sql/src/planner/optimizer/format.rs +++ b/src/query/sql/src/planner/optimizer/format.rs @@ -51,6 +51,7 @@ pub fn display_rel_op(rel_op: &RelOperator) -> String { RelOperator::Window(_) => "WindowFunc".to_string(), RelOperator::CteScan(_) => "CteScan".to_string(), RelOperator::MaterializedCte(_) => "MaterializedCte".to_string(), + RelOperator::Lambda(_) => "LambdaFunc".to_string(), } } diff --git a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs index 182b3119771ee..e7c7b85596e8b 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs @@ -25,6 +25,7 @@ use crate::plans::Aggregate; use crate::plans::CteScan; use crate::plans::DummyTableScan; use crate::plans::EvalScalar; +use crate::plans::Lambda; use crate::plans::ProjectSet; use crate::plans::RelOperator; use crate::ColumnEntry; @@ -333,6 +334,25 @@ impl UnusedColumnPruner { )) } + RelOperator::Lambda(p) => { + let mut used = vec![]; + // Keep all columns, as some lambda functions may be arguments to other lambda functions. + for s in p.items.iter() { + used.push(s.clone()); + s.scalar.used_columns().iter().for_each(|c| { + required.insert(*c); + }) + } + if used.is_empty() { + self.keep_required_columns(expr.child(0)?, required) + } else { + Ok(SExpr::create_unary( + Arc::new(RelOperator::Lambda(Lambda { items: used })), + Arc::new(self.keep_required_columns(expr.child(0)?, required)?), + )) + } + } + RelOperator::DummyTableScan(_) => Ok(expr.clone()), _ => Err(ErrorCode::Internal( diff --git a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs index e045c33fba526..2944e4795e6f9 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs @@ -178,6 +178,7 @@ impl SubqueryRewriter { ScalarExpr::ConstantExpr(_) => Ok((scalar.clone(), s_expr.clone())), ScalarExpr::WindowFunction(_) => Ok((scalar.clone(), s_expr.clone())), ScalarExpr::AggregateFunction(_) => Ok((scalar.clone(), s_expr.clone())), + ScalarExpr::LambdaFunction(_) => Ok((scalar.clone(), s_expr.clone())), ScalarExpr::FunctionCall(func) => { let mut args = vec![]; let mut s_expr = s_expr.clone(); diff --git a/src/query/sql/src/planner/optimizer/hyper_dp/dphyp.rs b/src/query/sql/src/planner/optimizer/hyper_dp/dphyp.rs index 847c7cf369d59..2e803e947cdc8 100644 --- a/src/query/sql/src/planner/optimizer/hyper_dp/dphyp.rs +++ b/src/query/sql/src/planner/optimizer/hyper_dp/dphyp.rs @@ -158,6 +158,7 @@ impl DPhpy { left_op, RelOperator::EvalScalar(_) | RelOperator::Aggregate(_) + | RelOperator::Lambda(_) | RelOperator::Sort(_) | RelOperator::Limit(_) | RelOperator::ProjectSet(_) @@ -169,6 +170,7 @@ impl DPhpy { right_op, RelOperator::EvalScalar(_) | RelOperator::Aggregate(_) + | RelOperator::Lambda(_) | RelOperator::Sort(_) | RelOperator::Limit(_) | RelOperator::ProjectSet(_) @@ -213,6 +215,7 @@ impl DPhpy { } RelOperator::ProjectSet(_) | RelOperator::Aggregate(_) + | RelOperator::Lambda(_) | RelOperator::Sort(_) | RelOperator::Limit(_) | RelOperator::EvalScalar(_) diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs index 88d4a6cd72a5d..1cc581958e366 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs @@ -166,6 +166,11 @@ fn replace_column(scalar: &mut ScalarExpr, col_to_scalar: &HashMap<&IndexType, & replace_column(arg, col_to_scalar) } } + ScalarExpr::LambdaFunction(expr) => { + for arg in expr.args.iter_mut() { + replace_column(arg, col_to_scalar) + } + } ScalarExpr::CastExpr(expr) => { replace_column(&mut expr.argument, col_to_scalar); } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs index 3b6d18bf9959f..bad89e0860d6b 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs @@ -29,6 +29,7 @@ use crate::plans::EvalScalar; use crate::plans::Filter; use crate::plans::FunctionCall; use crate::plans::LagLeadFunction; +use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::PatternPlan; use crate::plans::RelOp; @@ -198,6 +199,23 @@ impl RulePushDownFilterEvalScalar { func_name: func.func_name.clone(), })) } + ScalarExpr::LambdaFunction(lambda_func) => { + let args = lambda_func + .args + .iter() + .map(|arg| Self::replace_predicate(arg, items)) + .collect::>>()?; + + Ok(ScalarExpr::LambdaFunction(LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type.clone(), + })) + } ScalarExpr::CastExpr(cast) => { let arg = Self::replace_predicate(&cast.argument, items)?; Ok(ScalarExpr::CastExpr(CastExpr { diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs index c32237db42669..c19fcc1b8fc94 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs @@ -27,6 +27,7 @@ use crate::plans::CastExpr; use crate::plans::Filter; use crate::plans::FunctionCall; use crate::plans::LagLeadFunction; +use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::PatternPlan; use crate::plans::RelOp; @@ -210,6 +211,23 @@ impl RulePushDownFilterScan { display_name: agg_func.display_name.clone(), })) } + ScalarExpr::LambdaFunction(lambda_func) => { + let args = lambda_func + .args + .iter() + .map(|arg| Self::replace_view_column(arg, table_entries, column_entries)) + .collect::>>()?; + + Ok(ScalarExpr::LambdaFunction(LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type.clone(), + })) + } ScalarExpr::FunctionCall(func) => { let arguments = func .arguments diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs index d059a8562eb15..94a61fe77a431 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs @@ -29,6 +29,7 @@ use crate::plans::CastExpr; use crate::plans::Filter; use crate::plans::FunctionCall; use crate::plans::LagLeadFunction; +use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::PatternPlan; use crate::plans::RelOp; @@ -244,6 +245,23 @@ fn replace_column_binding( .map(|arg| replace_column_binding(index_pairs, arg)) .collect::>>()?, })), + ScalarExpr::LambdaFunction(lambda_func) => { + let args = lambda_func + .args + .into_iter() + .map(|arg| replace_column_binding(index_pairs, arg)) + .collect::>>()?; + + Ok(ScalarExpr::LambdaFunction(LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type, + })) + } ScalarExpr::CastExpr(expr) => Ok(ScalarExpr::CastExpr(CastExpr { span: expr.span, is_try: expr.is_try, diff --git a/src/query/sql/src/planner/optimizer/s_expr.rs b/src/query/sql/src/planner/optimizer/s_expr.rs index d4d859fcf23f9..21eb93fc341fb 100644 --- a/src/query/sql/src/planner/optimizer/s_expr.rs +++ b/src/query/sql/src/planner/optimizer/s_expr.rs @@ -290,6 +290,10 @@ fn find_subquery(rel_op: &RelOperator) -> bool { .srfs .iter() .any(|expr| find_subquery_in_expr(&expr.scalar)), + RelOperator::Lambda(op) => op + .items + .iter() + .any(|expr| find_subquery_in_expr(&expr.scalar)), } } @@ -305,6 +309,7 @@ fn find_subquery_in_expr(expr: &ScalarExpr) -> bool { || expr.order_by.iter().any(|o| find_subquery_in_expr(&o.expr)) } ScalarExpr::AggregateFunction(expr) => expr.args.iter().any(find_subquery_in_expr), + ScalarExpr::LambdaFunction(expr) => expr.args.iter().any(find_subquery_in_expr), ScalarExpr::FunctionCall(expr) => expr.arguments.iter().any(find_subquery_in_expr), ScalarExpr::CastExpr(expr) => find_subquery_in_expr(&expr.argument), ScalarExpr::SubqueryExpr(_) => true, diff --git a/src/query/sql/src/planner/plans/lambda.rs b/src/query/sql/src/planner/plans/lambda.rs new file mode 100644 index 0000000000000..8e493275a6428 --- /dev/null +++ b/src/query/sql/src/planner/plans/lambda.rs @@ -0,0 +1,107 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use common_catalog::table_context::TableContext; +use common_exception::Result; + +use crate::optimizer::ColumnSet; +use crate::optimizer::PhysicalProperty; +use crate::optimizer::RelExpr; +use crate::optimizer::RelationalProperty; +use crate::optimizer::RequiredProperty; +use crate::optimizer::StatInfo; +use crate::plans::Operator; +use crate::plans::RelOp; +use crate::plans::ScalarItem; + +/// `Lambda` is a plan that evaluate a series of lambda functions. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Lambda { + pub items: Vec, +} + +impl Lambda { + pub fn used_columns(&self) -> Result { + let mut used_columns = ColumnSet::new(); + for item in self.items.iter() { + used_columns.insert(item.index); + used_columns.extend(item.scalar.used_columns()); + } + Ok(used_columns) + } +} + +impl Operator for Lambda { + fn rel_op(&self) -> RelOp { + RelOp::Lambda + } + + fn derive_relational_prop( + &self, + rel_expr: &RelExpr, + ) -> common_exception::Result> { + let input_prop = rel_expr.derive_relational_prop_child(0)?; + + // Derive output columns + let mut output_columns = input_prop.output_columns.clone(); + for item in self.items.iter() { + output_columns.insert(item.index); + } + + // Derive outer columns + let mut outer_columns = input_prop.outer_columns.clone(); + for item in self.items.iter() { + let used_columns = item.scalar.used_columns(); + let outer = used_columns + .difference(&output_columns) + .cloned() + .collect::(); + outer_columns = outer_columns.union(&outer).cloned().collect(); + } + outer_columns = outer_columns.difference(&output_columns).cloned().collect(); + + // Derive used columns + let mut used_columns = self.used_columns()?; + used_columns.extend(input_prop.used_columns.clone()); + + Ok(Arc::new(RelationalProperty { + output_columns, + outer_columns, + used_columns, + })) + } + + fn derive_physical_prop( + &self, + rel_expr: &RelExpr, + ) -> common_exception::Result { + rel_expr.derive_physical_prop_child(0) + } + + fn derive_cardinality(&self, rel_expr: &RelExpr) -> common_exception::Result> { + rel_expr.derive_cardinality_child(0) + } + + fn compute_required_prop_child( + &self, + _ctx: Arc, + _rel_expr: &RelExpr, + _child_index: usize, + required: &RequiredProperty, + ) -> common_exception::Result { + Ok(required.clone()) + } +} diff --git a/src/query/sql/src/planner/plans/mod.rs b/src/query/sql/src/planner/plans/mod.rs index b28e63a862a6e..09f85641c3dec 100644 --- a/src/query/sql/src/planner/plans/mod.rs +++ b/src/query/sql/src/planner/plans/mod.rs @@ -26,6 +26,7 @@ mod filter; pub mod insert; mod join; mod kill; +mod lambda; mod limit; mod materialized_cte; mod operator; @@ -62,6 +63,7 @@ pub use insert::Insert; pub use insert::InsertInputSource; pub use join::*; pub use kill::KillPlan; +pub use lambda::*; pub use limit::*; pub use materialized_cte::MaterializedCte; pub use operator::*; diff --git a/src/query/sql/src/planner/plans/operator.rs b/src/query/sql/src/planner/plans/operator.rs index a23c80ce33aa7..4a94eed75b2b4 100644 --- a/src/query/sql/src/planner/plans/operator.rs +++ b/src/query/sql/src/planner/plans/operator.rs @@ -37,6 +37,7 @@ use crate::plans::materialized_cte::MaterializedCte; use crate::plans::runtime_filter_source::RuntimeFilterSource; use crate::plans::CteScan; use crate::plans::Exchange; +use crate::plans::Lambda; use crate::plans::ProjectSet; use crate::plans::Window; @@ -80,6 +81,7 @@ pub enum RelOp { Window, ProjectSet, MaterializedCte, + Lambda, // Pattern Pattern, @@ -103,6 +105,7 @@ pub enum RelOperator { Window(Window), ProjectSet(ProjectSet), MaterializedCte(MaterializedCte), + Lambda(Lambda), Pattern(PatternPlan), } @@ -126,6 +129,7 @@ impl Operator for RelOperator { RelOperator::Window(rel_op) => rel_op.rel_op(), RelOperator::CteScan(rel_op) => rel_op.rel_op(), RelOperator::MaterializedCte(rel_op) => rel_op.rel_op(), + RelOperator::Lambda(rel_op) => rel_op.rel_op(), } } @@ -147,6 +151,7 @@ impl Operator for RelOperator { RelOperator::Window(rel_op) => rel_op.derive_relational_prop(rel_expr), RelOperator::CteScan(rel_op) => rel_op.derive_relational_prop(rel_expr), RelOperator::MaterializedCte(rel_op) => rel_op.derive_relational_prop(rel_expr), + RelOperator::Lambda(rel_op) => rel_op.derive_relational_prop(rel_expr), } } @@ -168,6 +173,7 @@ impl Operator for RelOperator { RelOperator::Window(rel_op) => rel_op.derive_physical_prop(rel_expr), RelOperator::CteScan(rel_op) => rel_op.derive_physical_prop(rel_expr), RelOperator::MaterializedCte(rel_op) => rel_op.derive_physical_prop(rel_expr), + RelOperator::Lambda(rel_op) => rel_op.derive_physical_prop(rel_expr), } } @@ -189,6 +195,7 @@ impl Operator for RelOperator { RelOperator::Window(rel_op) => rel_op.derive_cardinality(rel_expr), RelOperator::CteScan(rel_op) => rel_op.derive_cardinality(rel_expr), RelOperator::MaterializedCte(rel_op) => rel_op.derive_cardinality(rel_expr), + RelOperator::Lambda(rel_op) => rel_op.derive_cardinality(rel_expr), } } @@ -248,6 +255,9 @@ impl Operator for RelOperator { RelOperator::MaterializedCte(rel_op) => { rel_op.compute_required_prop_child(ctx, rel_expr, child_index, required) } + RelOperator::Lambda(rel_op) => { + rel_op.compute_required_prop_child(ctx, rel_expr, child_index, required) + } } } } @@ -552,3 +562,21 @@ impl TryFrom for ProjectSet { } } } + +impl From for RelOperator { + fn from(value: Lambda) -> Self { + Self::Lambda(value) + } +} + +impl TryFrom for Lambda { + type Error = ErrorCode; + + fn try_from(value: RelOperator) -> std::result::Result { + if let RelOperator::Lambda(value) = value { + Ok(value) + } else { + Err(ErrorCode::Internal("Cannot downcast RelOperator to Lambda")) + } + } +} diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs index f9b9a1d30d0dc..b2f3b5a2a8bde 100644 --- a/src/query/sql/src/planner/plans/scalar_expr.rs +++ b/src/query/sql/src/planner/plans/scalar_expr.rs @@ -39,6 +39,7 @@ pub enum ScalarExpr { ConstantExpr(ConstantExpr), WindowFunction(WindowFunc), AggregateFunction(AggregateFunction), + LambdaFunction(LambdaFunc), FunctionCall(FunctionCall), CastExpr(CastExpr), SubqueryExpr(SubqueryExpr), @@ -70,6 +71,13 @@ impl ScalarExpr { } result } + ScalarExpr::LambdaFunction(scalar) => { + let mut result = ColumnSet::new(); + for scalar in &scalar.args { + result = result.union(&scalar.used_columns()).cloned().collect(); + } + result + } ScalarExpr::FunctionCall(scalar) => { let mut result = ColumnSet::new(); for scalar in &scalar.arguments { @@ -107,6 +115,13 @@ impl ScalarExpr { } Ok(result) } + ScalarExpr::LambdaFunction(scalar) => { + let mut result = vec![]; + for scalar in &scalar.args { + result.append(&mut scalar.used_tables(metadata.clone())?); + } + Ok(result) + } ScalarExpr::CastExpr(scalar) => scalar.argument.used_tables(metadata), ScalarExpr::WindowFunction(_) | ScalarExpr::SubqueryExpr(_) => { Err(ErrorCode::Unimplemented( @@ -142,8 +157,11 @@ impl ScalarExpr { ScalarExpr::WindowFunction(_) | ScalarExpr::AggregateFunction(_) | ScalarExpr::SubqueryExpr(_) => false, - ScalarExpr::FunctionCall(expr) => { - expr.arguments.iter().all(|arg| arg.valid_for_clustering()) + ScalarExpr::FunctionCall(func) => { + func.arguments.iter().all(|arg| arg.valid_for_clustering()) + } + ScalarExpr::LambdaFunction(func) => { + func.args.iter().all(|arg| arg.valid_for_clustering()) } ScalarExpr::CastExpr(expr) => expr.argument.valid_for_clustering(), } @@ -225,6 +243,24 @@ impl TryFrom for WindowFunc { } } +impl From for ScalarExpr { + fn from(v: LambdaFunc) -> Self { + Self::LambdaFunction(v) + } +} + +impl TryFrom for LambdaFunc { + type Error = ErrorCode; + + fn try_from(value: ScalarExpr) -> Result { + if let ScalarExpr::LambdaFunction(value) = value { + Ok(value) + } else { + Err(ErrorCode::Internal("Cannot downcast Scalar to LambdaFunc")) + } + } +} + impl From for ScalarExpr { fn from(v: FunctionCall) -> Self { Self::FunctionCall(v) @@ -413,6 +449,19 @@ pub struct WindowOrderBy { pub nulls_first: Option, } +#[derive(Clone, Debug, Educe)] +#[educe(PartialEq, Eq, Hash)] +pub struct LambdaFunc { + #[educe(PartialEq(ignore), Eq(ignore), Hash(ignore))] + pub span: Span, + pub func_name: String, + pub display_name: String, + pub args: Vec, + pub params: Vec<(String, DataType)>, + pub lambda_expr: Box, + pub return_type: Box, +} + #[derive(Clone, Debug, Educe)] #[educe(PartialEq, Eq, Hash)] pub struct FunctionCall { diff --git a/src/query/sql/src/planner/semantic/aggregating_index_rewriter.rs b/src/query/sql/src/planner/semantic/aggregating_index_rewriter.rs index adf5513e3ce17..3dcb60c263929 100644 --- a/src/query/sql/src/planner/semantic/aggregating_index_rewriter.rs +++ b/src/query/sql/src/planner/semantic/aggregating_index_rewriter.rs @@ -59,6 +59,7 @@ impl VisitorMut for AggregatingIndexRewriter { args: vec![], params: vec![], window: None, + lambda: None, }; } _ => {} diff --git a/src/query/sql/src/planner/semantic/distinct_to_groupby.rs b/src/query/sql/src/planner/semantic/distinct_to_groupby.rs index 60bb988cf7e5b..2a2c9a6d4c298 100644 --- a/src/query/sql/src/planner/semantic/distinct_to_groupby.rs +++ b/src/query/sql/src/planner/semantic/distinct_to_groupby.rs @@ -89,6 +89,7 @@ impl VisitorMut for DistinctToGroupBy { args: vec![], params: vec![], window: None, + lambda: None, }), alias: alias.clone(), }], diff --git a/src/query/sql/src/planner/semantic/grouping_check.rs b/src/query/sql/src/planner/semantic/grouping_check.rs index 25d9efda298ec..e7dfd614e4aa4 100644 --- a/src/query/sql/src/planner/semantic/grouping_check.rs +++ b/src/query/sql/src/planner/semantic/grouping_check.rs @@ -21,6 +21,7 @@ use crate::binder::Visibility; use crate::plans::BoundColumnRef; use crate::plans::CastExpr; use crate::plans::FunctionCall; +use crate::plans::LambdaFunc; use crate::plans::ScalarExpr; use crate::BindContext; @@ -99,6 +100,24 @@ impl<'a> GroupingChecker<'a> { } .into()) } + + ScalarExpr::LambdaFunction(lambda_func) => { + let args = lambda_func + .args + .iter() + .map(|arg| self.resolve(arg, span)) + .collect::>>()?; + Ok(LambdaFunc { + span: lambda_func.span, + func_name: lambda_func.func_name.clone(), + display_name: lambda_func.display_name.clone(), + args, + params: lambda_func.params.clone(), + lambda_expr: lambda_func.lambda_expr.clone(), + return_type: lambda_func.return_type.clone(), + } + .into()) + } ScalarExpr::CastExpr(cast) => Ok(CastExpr { span: cast.span, is_try: cast.is_try, diff --git a/src/query/sql/src/planner/semantic/lowering.rs b/src/query/sql/src/planner/semantic/lowering.rs index 960f9f2362227..666c170f25c54 100644 --- a/src/query/sql/src/planner/semantic/lowering.rs +++ b/src/query/sql/src/planner/semantic/lowering.rs @@ -215,6 +215,22 @@ impl ScalarExpr { data_type: (*agg.return_type).clone(), display_name: agg.display_name.clone(), }, + ScalarExpr::LambdaFunction(func) => RawExpr::ColumnRef { + span: None, + id: ColumnBinding { + database_name: None, + table_name: None, + table_index: None, + column_position: None, + column_name: func.display_name.clone(), + index: usize::MAX, + data_type: Box::new((*func.return_type).clone()), + visibility: Visibility::Visible, + virtual_computed_expr: None, + }, + data_type: (*func.return_type).clone(), + display_name: func.display_name.clone(), + }, ScalarExpr::FunctionCall(func) => RawExpr::FunctionCall { span: func.span, name: func.func_name.clone(), diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index d4fa55dd1eb62..06f900297058c 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -62,6 +62,7 @@ use common_functions::aggregates::AggregateCountFunction; use common_functions::aggregates::AggregateFunctionFactory; use common_functions::is_builtin_function; use common_functions::BUILTIN_FUNCTIONS; +use common_functions::GENERAL_LAMBDA_FUNCTIONS; use common_functions::GENERAL_WINDOW_FUNCTIONS; use common_users::UserApiProvider; use simsearch::SimSearch; @@ -74,6 +75,7 @@ use crate::binder::ColumnBindingBuilder; use crate::binder::ExprContext; use crate::binder::NameResolutionResult; use crate::optimizer::RelExpr; +use crate::parse_lambda_expr; use crate::planner::metadata::optimize_remove_count_args; use crate::plans::AggregateFunction; use crate::plans::BoundColumnRef; @@ -82,6 +84,7 @@ use crate::plans::ComparisonOp; use crate::plans::ConstantExpr; use crate::plans::FunctionCall; use crate::plans::LagLeadFunction; +use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::NtileFunction; use crate::plans::ScalarExpr; @@ -350,6 +353,7 @@ impl<'a> TypeChecker<'a> { args: args.iter().copied().cloned().collect(), params: vec![], window: None, + lambda: None, }) .await? } else { @@ -575,6 +579,7 @@ impl<'a> TypeChecker<'a> { args: vec![*operand.clone(), c.clone()], params: vec![], window: None, + lambda: None, }; arguments.push(equal_expr) } @@ -625,6 +630,7 @@ impl<'a> TypeChecker<'a> { args, params, window, + lambda, } => { let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string(); let func_name = func_name.as_str(); @@ -639,6 +645,8 @@ impl<'a> TypeChecker<'a> { .all_function_names() .into_iter() .chain(AggregateFunctionFactory::instance().registered_names()) + .chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string)) + .chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string)) .chain( Self::all_rewritable_scalar_function() .iter() @@ -700,6 +708,15 @@ impl<'a> TypeChecker<'a> { let name = func_name.to_lowercase(); if GENERAL_WINDOW_FUNCTIONS.contains(&name.as_str()) { + if matches!( + self.bind_context.expr_context, + ExprContext::InLambdaFunction + ) { + return Err(ErrorCode::SemanticError( + "window functions can not be used in lambda function".to_string(), + ) + .set_span(*span)); + } // general window function if window.is_none() { return Err(ErrorCode::SemanticError(format!( @@ -712,6 +729,16 @@ impl<'a> TypeChecker<'a> { self.resolve_window(*span, display_name, window, func) .await? } else if AggregateFunctionFactory::instance().contains(&name) { + if matches!( + self.bind_context.expr_context, + ExprContext::InLambdaFunction + ) { + return Err(ErrorCode::SemanticError( + "aggregate functions can not be used in lambda function".to_string(), + ) + .set_span(*span)); + } + let in_window = self.in_window_function; self.in_window_function = self.in_window_function || window.is_some(); let (new_agg_func, data_type) = self @@ -728,6 +755,93 @@ impl<'a> TypeChecker<'a> { // aggregate function Box::new((new_agg_func.into(), data_type)) } + } else if GENERAL_LAMBDA_FUNCTIONS.contains(&name.as_str()) { + if matches!( + self.bind_context.expr_context, + ExprContext::InLambdaFunction + ) { + return Err(ErrorCode::SemanticError( + "lambda functions can not be used in lambda function".to_string(), + ) + .set_span(*span)); + } + if lambda.is_none() { + return Err(ErrorCode::SemanticError(format!( + "function {name} must have a lambda expression", + ))); + } + let lambda = lambda.as_ref().unwrap(); + + let params = lambda + .params + .iter() + .map(|param| param.name.clone()) + .collect::>(); + + // TODO: support multiple params + if params.len() != 1 { + return Err(ErrorCode::SemanticError(format!( + "incorrect number of parameters in lambda function, {name} expects 1 parameter", + ))); + } + + if args.len() != 1 { + return Err(ErrorCode::SemanticError(format!( + "invalid arguments for lambda function, {name} expects 1 argument" + ))); + } + let box (arg, arg_type) = self.resolve(args[0]).await?; + match arg_type.remove_nullable() { + // Empty array will always return an Empty array + DataType::EmptyArray => Box::new(( + ConstantExpr { + span: *span, + value: Scalar::EmptyArray, + } + .into(), + DataType::EmptyArray, + )), + DataType::Array(box inner_ty) => { + let box (lambda_expr, lambda_type) = parse_lambda_expr( + self.ctx.clone(), + ¶ms[0], + &inner_ty, + &lambda.expr, + )?; + + let return_type = if name == "array_filter" { + if lambda_type.remove_nullable() == DataType::Boolean { + arg_type + } else { + return Err(ErrorCode::SemanticError( + "invalid lambda function for `array_filter`, the result data type of lambda function must be boolean".to_string() + )); + } + } else if arg_type.is_nullable() { + DataType::Nullable(Box::new(DataType::Array(Box::new(lambda_type)))) + } else { + DataType::Array(Box::new(lambda_type)) + }; + Box::new(( + LambdaFunc { + span: *span, + func_name: name.clone(), + display_name: format!("{:#}", expr), + args: vec![arg], + params: vec![(params[0].clone(), inner_ty)], + lambda_expr: Box::new(lambda_expr), + return_type: Box::new(return_type.clone()), + } + .into(), + return_type, + )) + } + _ => { + return Err(ErrorCode::SemanticError( + "invalid arguments for lambda function, argument data type must be array".to_string() + )); + } + } } else { // Scalar function let params = params @@ -2029,6 +2143,7 @@ impl<'a> TypeChecker<'a> { args: vec![arg_x.clone()], params: vec![], window: None, + lambda: None, }) .await, ) @@ -2065,6 +2180,7 @@ impl<'a> TypeChecker<'a> { args: vec![(*arg).clone()], params: vec![], window: None, + lambda: None, }; new_args.push(is_not_null_expr); @@ -2947,6 +3063,7 @@ impl<'a> TypeChecker<'a> { args, params, window, + lambda, } => Ok(Expr::FunctionCall { span: *span, distinct: *distinct, @@ -2957,6 +3074,7 @@ impl<'a> TypeChecker<'a> { .collect::>>()?, params: params.clone(), window: window.clone(), + lambda: lambda.clone(), }), Expr::Case { span, diff --git a/src/query/sql/src/planner/semantic/window_check.rs b/src/query/sql/src/planner/semantic/window_check.rs index 96077dc0b1918..8379fbf474d91 100644 --- a/src/query/sql/src/planner/semantic/window_check.rs +++ b/src/query/sql/src/planner/semantic/window_check.rs @@ -49,6 +49,17 @@ impl<'a> WindowChecker<'a> { } .into()) } + ScalarExpr::LambdaFunction(lambda) => { + if let Some(column_ref) = self + .bind_context + .lambda_info + .lambda_functions_map + .get(&lambda.display_name) + { + return Ok(column_ref.clone().into()); + } + Err(ErrorCode::Internal("Window Check: Invalid lambda function")) + } ScalarExpr::CastExpr(cast) => Ok(CastExpr { span: cast.span, is_try: cast.is_try, diff --git a/src/query/sql/src/planner/udf_validator.rs b/src/query/sql/src/planner/udf_validator.rs index 8e770fbdd077a..ff9941231fc43 100644 --- a/src/query/sql/src/planner/udf_validator.rs +++ b/src/query/sql/src/planner/udf_validator.rs @@ -17,6 +17,7 @@ use std::collections::HashSet; use common_ast::ast::ColumnID; use common_ast::ast::Expr; use common_ast::ast::Identifier; +use common_ast::ast::Lambda; use common_ast::ast::Literal; use common_ast::ast::Window; use common_ast::walk_expr; @@ -89,6 +90,7 @@ impl<'ast> Visitor<'ast> for UDFValidator { args: &'ast [Expr], _params: &'ast [Literal], over: &'ast Option, + lambda: &'ast Option, ) { let name = name.to_string(); if !is_builtin_function(&name) && self.name.eq_ignore_ascii_case(&name) { @@ -120,5 +122,8 @@ impl<'ast> Visitor<'ast> for UDFValidator { } } } + if let Some(lambda) = lambda { + walk_expr(self, &lambda.expr) + } } } diff --git a/src/query/storages/system/src/query_profile_table.rs b/src/query/storages/system/src/query_profile_table.rs index 808fd346227d8..be641fe4c5172 100644 --- a/src/query/storages/system/src/query_profile_table.rs +++ b/src/query/storages/system/src/query_profile_table.rs @@ -66,6 +66,9 @@ fn encode_operator_attribute(attr: &OperatorAttribute) -> jsonb::Value { OperatorAttribute::ProjectSet(project_attr) => { (&serde_json::json!({ "functions": project_attr.functions })).into() } + OperatorAttribute::Lambda(lambda_attr) => { + (&serde_json::json!({ "scalars": lambda_attr.scalars })).into() + } OperatorAttribute::Limit(limit_attr) => (&serde_json::json!({ "limit": limit_attr.limit, "offset": limit_attr.offset, diff --git a/tests/sqllogictests/suites/query/02_function/02_0061_function_array b/tests/sqllogictests/suites/query/02_function/02_0061_function_array index c0c53d169ffad..5b061b0d48634 100644 --- a/tests/sqllogictests/suites/query/02_function/02_0061_function_array +++ b/tests/sqllogictests/suites/query/02_function/02_0061_function_array @@ -177,5 +177,51 @@ select array_sort(col1, 'asc', 'nulls fir') from t; statement error 1065 select array_sort(col1, 'asca', 'nulls first') from t; +query T +select array_transform([1, 2, NULL, 3], x -> x + 1) +---- +[2,3,NULL,4] + +query T +select array_transform(['data', 'a', 'b'], data -> CONCAT(data, 'bend')) +---- +['databend','abend','bbend'] + +query T +select array_apply(array_apply([5, NULL, 6], x -> COALESCE(x, 0) + 1), y -> y + 10) +---- +[16,11,17] + +query TT +select array_transform(col1, a -> a * 2), array_apply(col2, b -> upper(b)) from t +---- +[2,4,6,6] ['X','X','Y','Z'] + +statement error 1065 +select array_transform([1, 2], x -> y + 1) + +query T +select array_filter([5, -6, NULL, 7], x -> x > 0) +---- +[5,7] + +query T +select array_filter(['Hello', 'abc World'], x -> x LIKE '%World%'); +---- +['abc World'] + +query T +select array_filter(array_filter([2, 4, 3, 1, 20, 10, 3, 30], x -> x % 2 = 0), y -> y % 5 = 0) +---- +[20,10,30] + +query TT +select array_filter(col1, a -> a % 2 = 1), array_filter(col2, b -> b = 'x') from t +---- +[1,3,3] ['x','x'] + +statement error 1065 +select array_filter([1, 2], x -> x + 1) + statement ok DROP DATABASE array_func_test