Skip to content

Hacky extension point for aggregate function determination #2992

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 121 additions & 106 deletions sql/planbuilder/aggregates.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,8 @@ func (b *Builder) buildGroupingCols(fromScope, projScope *scope, groupby ast.Gro
// 3) an index into selects
// 4) a simple non-aggregate expression
groupings := make([]sql.Expression, 0)
if fromScope.groupBy == nil {
fromScope.initGroupBy()
}
fromScope.initGroupBy()

g := fromScope.groupBy
for _, e := range groupby {
var col scopeColumn
Expand Down Expand Up @@ -194,9 +193,7 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
// - grouping cols projection
// - aggregate expressions
// - output projection
if fromScope.groupBy == nil {
fromScope.initGroupBy()
}
fromScope.initGroupBy()

group := fromScope.groupBy
outScope := group.outScope
Expand Down Expand Up @@ -257,7 +254,10 @@ func (b *Builder) buildAggregation(fromScope, projScope *scope, groupingCols []s
return outScope
}

func isAggregateFunc(name string) bool {
// IsAggregateFunc is a hacky "extension point" to allow for other dialects to declare additional aggregate functions
var IsAggregateFunc = IsMySQLAggregateFuncName

func IsMySQLAggregateFuncName(name string) bool {
switch name {
case "avg", "bit_and", "bit_or", "bit_xor", "count",
"group_concat", "json_arrayagg", "json_objectagg",
Expand All @@ -278,111 +278,63 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
b.handleErr(err)
}

if inScope.groupBy == nil {
inScope.initGroupBy()
}
inScope.initGroupBy()
gb := inScope.groupBy

if strings.EqualFold(name, "count") {
if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
var agg sql.Aggregation
if e.Distinct {
agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64))
} else {
agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
}
b.qFlags.Set(sql.QFlagCountStar)
aggName := strings.ToLower(agg.String())
gf := gb.getAggRef(aggName)
if gf != nil {
// if we've already computed use reference here
return gf
}

col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
id := gb.outScope.newColumn(col)
col.id = id

agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
col.scalar = agg

gb.addAggStr(col)
return col.scalarGf()
return b.buildCountStarAggregate(e, gb)
}
}

if strings.EqualFold(name, "jsonarray") {
// TODO we don't have any tests for this
if _, ok := e.Exprs[0].(*ast.StarExpr); ok {
var agg sql.Aggregation
agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64))
b.qFlags.Set(sql.QFlagStar)

//if e.Distinct {
// agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
//}
aggName := strings.ToLower(agg.String())
gf := gb.getAggRef(aggName)
if gf != nil {
// if we've already computed use reference here
return gf
}

col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
id := gb.outScope.newColumn(col)

agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray)
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
col.scalar = agg

col.id = id
gb.addAggStr(col)
return col.scalarGf()
return b.buildJsonArrayStarAggregate(gb)
}
}

if strings.EqualFold(name, "any_value") {
b.qFlags.Set(sql.QFlagAnyAgg)
}

var args []sql.Expression
for _, arg := range e.Exprs {
e := b.selectExprToExpression(inScope, arg)
switch e := e.(type) {
case *expression.GetField:
if e.TableId() == 0 {
// TODO: not sure where this came from but it's not true
// aliases are not valid aggregate arguments, the alias must be masking a column
gf := b.selectExprToExpression(inScope.parent, arg)
var ok bool
e, ok = gf.(*expression.GetField)
if !ok || e.TableId() == 0 {
b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", gf))
}
}
args = append(args, e)
col := scopeColumn{tableId: e.TableID(), db: e.Database(), table: e.Table(), col: e.Name(), scalar: e, typ: e.Type(), nullable: e.IsNullable()}
gb.addInCol(col)
case *expression.Star:
err := sql.ErrStarUnsupported.New()
b.handleErr(err)
case *plan.Subquery:
args = append(args, e)
col := scopeColumn{col: e.QueryString, scalar: e, typ: e.Type()}
gb.addInCol(col)
default:
args = append(args, e)
col := scopeColumn{col: e.String(), scalar: e, typ: e.Type()}
gb.addInCol(col)
}
args := b.buildAggFunctionArgs(inScope, e, gb)
agg := b.newAggregation(e, name, args)

if name == "count" {
b.qFlags.Set(sql.QFlagCount)
}

aggType := agg.Type()
if name == "avg" || name == "sum" {
aggType = types.Float64
}

aggName := strings.ToLower(plan.AliasSubqueryString(agg))
if id, ok := gb.outScope.getExpr(aggName, true); ok {
// if we've already computed use reference here
gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable())
return gf
}

col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()}
id := gb.outScope.newColumn(col)

agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
col.scalar = agg

col.id = id
gb.addAggStr(col)
return col.scalarGf()
}

// newAggregation creates a new aggregation function instance from the arguments given
func (b *Builder) newAggregation(e *ast.FuncExpr, name string, args []sql.Expression) sql.Aggregation {
var agg sql.Aggregation
if e.Distinct && name == "count" {
agg = aggregation.NewCountDistinct(args...)
} else {

// NOTE: Not all aggregate functions support DISTINCT. Fortunately, the vitess parser returns an
// error when DISTINCT is used with an aggregate function that doesn't support it.
if e.Distinct {
Expand Down Expand Up @@ -412,39 +364,104 @@ func (b *Builder) buildAggregateFunc(inScope *scope, name string, e *ast.FuncExp
b.handleErr(err)
}
}
return agg
}

if name == "count" {
b.qFlags.Set(sql.QFlagCount)
// buildAggFunctionArgs converts every argument of the aggregate call |e| into a
// sql.Expression, resolving each one against |inScope|. Every accepted argument
// is also registered on |gb| as an aggregation input column via addInCol. The
// collected expressions are returned in argument order; the result is nil when
// the call has no arguments.
//
// A star argument is not accepted here and is reported through b.handleErr
// with sql.ErrStarUnsupported.
func (b *Builder) buildAggFunctionArgs(inScope *scope, e *ast.FuncExpr, gb *groupBy) []sql.Expression {
	var args []sql.Expression
	for _, selExpr := range e.Exprs {
		switch resolved := b.selectExprToExpression(inScope, selExpr).(type) {
		case *expression.Star:
			b.handleErr(sql.ErrStarUnsupported.New())

		case *plan.Subquery:
			args = append(args, resolved)
			gb.addInCol(scopeColumn{col: resolved.QueryString, scalar: resolved, typ: resolved.Type()})

		case *expression.GetField:
			field := resolved
			if field.TableId() == 0 {
				// A field with no table id is treated as an alias. The existing
				// TODO questions this rule: aliases are not valid aggregate
				// arguments, so the alias is assumed to be masking a real column,
				// which is looked up again in the parent scope.
				outer := b.selectExprToExpression(inScope.parent, selExpr)
				parentField, isField := outer.(*expression.GetField)
				if !isField || parentField.TableId() == 0 {
					b.handleErr(fmt.Errorf("failed to resolve aggregate column argument: %s", outer))
				}
				field = parentField
			}
			args = append(args, field)
			gb.addInCol(scopeColumn{
				tableId:  field.TableID(),
				db:       field.Database(),
				table:    field.Table(),
				col:      field.Name(),
				scalar:   field,
				typ:      field.Type(),
				nullable: field.IsNullable(),
			})

		default:
			// Any other expression is passed through unchanged and tracked under
			// its string form.
			args = append(args, resolved)
			gb.addInCol(scopeColumn{col: resolved.String(), scalar: resolved, typ: resolved.Type()})
		}
	}
	return args
}

aggType := agg.Type()
if name == "avg" || name == "sum" {
aggType = types.Float64
// buildJsonArrayStarAggregate builds a JSON_ARRAY(*) aggregate function
func (b *Builder) buildJsonArrayStarAggregate(gb *groupBy) sql.Expression {
var agg sql.Aggregation
agg = aggregation.NewJsonArray(expression.NewLiteral(expression.NewStar(), types.Int64))
b.qFlags.Set(sql.QFlagStar)

// if e.Distinct {
// agg = plan.NewDistinct(expression.NewLiteral(1, types.Int64))
// }
aggName := strings.ToLower(agg.String())
gf := gb.getAggRef(aggName)
if gf != nil {
// if we've already computed use reference here
return gf
}

aggName := strings.ToLower(plan.AliasSubqueryString(agg))
if id, ok := gb.outScope.getExpr(aggName, true); ok {
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
id := gb.outScope.newColumn(col)

agg = agg.WithId(sql.ColumnId(id)).(*aggregation.JsonArray)
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
col.scalar = agg

col.id = id
gb.addAggStr(col)
return col.scalarGf()
}

// buildCountStarAggregate builds a COUNT(*) aggregate function
func (b *Builder) buildCountStarAggregate(e *ast.FuncExpr, gb *groupBy) sql.Expression {
var agg sql.Aggregation
if e.Distinct {
agg = aggregation.NewCountDistinct(expression.NewLiteral(1, types.Int64))
} else {
agg = aggregation.NewCount(expression.NewLiteral(1, types.Int64))
}
b.qFlags.Set(sql.QFlagCountStar)
aggName := strings.ToLower(agg.String())
gf := gb.getAggRef(aggName)
if gf != nil {
// if we've already computed use reference here
gf := expression.NewGetFieldWithTable(int(id), 0, aggType, "", "", aggName, agg.IsNullable())
return gf
}

col := scopeColumn{col: aggName, scalar: agg, typ: aggType, nullable: agg.IsNullable()}
col := scopeColumn{col: strings.ToLower(agg.String()), scalar: agg, typ: agg.Type(), nullable: agg.IsNullable()}
id := gb.outScope.newColumn(col)
col.id = id

agg = agg.WithId(sql.ColumnId(id)).(sql.Aggregation)
gb.outScope.cols[len(gb.outScope.cols)-1].scalar = agg
col.scalar = agg

col.id = id
gb.addAggStr(col)
return col.scalarGf()
}

// buildGroupConcat builds a GROUP_CONCAT aggregate function
func (b *Builder) buildGroupConcat(inScope *scope, e *ast.GroupConcatExpr) sql.Expression {
if inScope.groupBy == nil {
inScope.initGroupBy()
}
inScope.initGroupBy()
gb := inScope.groupBy

args := make([]sql.Expression, len(e.Exprs))
Expand Down Expand Up @@ -794,7 +811,7 @@ func (b *Builder) analyzeHaving(fromScope, projScope *scope, having *ast.Where)
return false, nil
case *ast.FuncExpr:
name := n.Name.Lowered()
if isAggregateFunc(name) {
if IsAggregateFunc(name) {
// record aggregate
// TODO: this should get projScope as well
_ = b.buildAggregateFunc(fromScope, name, n)
Expand Down Expand Up @@ -874,9 +891,7 @@ func (b *Builder) buildHaving(fromScope, projScope, outScope *scope, having *ast
if having == nil {
return
}
if fromScope.groupBy == nil {
fromScope.initGroupBy()
}
fromScope.initGroupBy()

havingScope := b.newScope()
if fromScope.parent != nil {
Expand Down
2 changes: 1 addition & 1 deletion sql/planbuilder/scalar.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (b *Builder) buildScalar(inScope *scope, e ast.Expr) (ex sql.Expression) {
return b.buildNameConst(inScope, v)
} else if name == "icu_version" {
return expression.NewLiteral(icuVersion, types.MustCreateString(query.Type_VARCHAR, int64(len(icuVersion)), sql.Collation_Default))
} else if isAggregateFunc(name) && v.Over == nil {
} else if IsAggregateFunc(name) && v.Over == nil {
// TODO this assumes aggregate is in the same scope
// also need to avoid nested aggregates
return b.buildAggregateFunc(inScope, name, v)
Expand Down
4 changes: 3 additions & 1 deletion sql/planbuilder/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ func (s *scope) initProc() {
// initGroupBy creates a container scope for aggregation
// functions and function inputs.
func (s *scope) initGroupBy() {
s.groupBy = &groupBy{outScope: s.replace()}
if s.groupBy == nil {
s.groupBy = &groupBy{outScope: s.replace()}
}
}

// pushSubquery creates a new scope with the subquery already initialized.
Expand Down
2 changes: 1 addition & 1 deletion sql/planbuilder/show.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ func (b *Builder) buildAsOfExpr(inScope *scope, time ast.Expr) sql.Expression {
return expression.NewLiteral(v.String(), types.LongText)
case *ast.FuncExpr:
// todo(max): more specific validation for nested ASOF functions
if isWindowFunc(v.Name.Lowered()) || isAggregateFunc(v.Name.Lowered()) {
if isWindowFunc(v.Name.Lowered()) || IsAggregateFunc(v.Name.Lowered()) {
err := sql.ErrInvalidAsOfExpression.New(v)
b.handleErr(err)
}
Expand Down