Skip to content

Commit 372b621

Browse files
authored
fix aggregate expressions with aliases, fixes src-d#112. (src-d#116)
1 parent ec4a396 commit 372b621

File tree

3 files changed

+42
-21
lines changed

3 files changed

+42
-21
lines changed

engine_test.go

+16-14
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ func TestEngine_Query(t *testing.T) {
4747
"SELECT COUNT(*) FROM mytable;",
4848
[][]interface{}{{int64(3)}},
4949
)
50+
51+
testQuery(t, e,
52+
"SELECT COUNT(*) AS c FROM mytable;",
53+
[][]interface{}{{int64(3)}},
54+
)
5055
}
5156

5257
func testQuery(t *testing.T, e *gitql.Engine, q string, r [][]interface{}) {
@@ -64,29 +69,26 @@ func testQuery(t *testing.T, e *gitql.Engine, q string, r [][]interface{}) {
6469
assert.NoError(err)
6570
assert.Equal(len(r[0]), len(cols))
6671

72+
vals := make([]interface{}, len(cols))
73+
valPtrs := make([]interface{}, len(cols))
74+
for i := 0; i < len(cols); i++ {
75+
valPtrs[i] = &vals[i]
76+
}
77+
6778
i := 0
6879
for {
6980
if !res.Next() {
7081
break
7182
}
7283

73-
expectedRow := r[i]
74-
i++
75-
76-
row := make([]interface{}, len(expectedRow))
77-
for i := range row {
78-
i64 := int64(0)
79-
row[i] = &i64
80-
}
84+
err := res.Scan(valPtrs...)
85+
assert.NoError(err)
8186

82-
assert.NoError(res.Scan(row...))
83-
for i := range row {
84-
row[i] = *(row[i].(*int64))
85-
}
86-
87-
assert.Equal(expectedRow, row)
87+
assert.Equal(r[i], vals)
88+
i++
8889
}
8990

91+
assert.NoError(res.Err())
9092
assert.Equal(len(r), i)
9193
}
9294

sql/parse/parse.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,17 @@ func limitToLimit(o sqlparser.Expr, child sql.Node) (*plan.Limit, error) {
190190
return plan.NewLimit(n, child), nil
191191
}
192192

193+
func isAggregate(e sql.Expression) bool {
194+
switch v := e.(type) {
195+
case *expression.UnresolvedFunction:
196+
return v.IsAggregate
197+
case *expression.Alias:
198+
return isAggregate(v.Child)
199+
default:
200+
return false
201+
}
202+
}
203+
193204
func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, child sql.Node) (sql.Node, error) {
194205
selectExprs, err := selectExprsToExpressions(se)
195206
if err != nil {
@@ -199,8 +210,9 @@ func selectToProjectOrGroupBy(se sqlparser.SelectExprs, g sqlparser.GroupBy, chi
199210
isAgg := len(g) > 0
200211
if !isAgg {
201212
for _, e := range selectExprs {
202-
if u, ok := e.(*expression.UnresolvedFunction); ok {
203-
isAgg = u.IsAggregate
213+
if isAggregate(e) {
214+
isAgg = true
215+
break
204216
}
205217
}
206218
}

sql/plan/group_by.go

+12-5
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,19 @@ func aggregate(exprs []sql.Expression, rows []sql.Row) sql.Row {
183183
func exprsToAggregateExprs(exprs []sql.Expression) []sql.AggregationExpression {
184184
var r []sql.AggregationExpression
185185
for _, e := range exprs {
186-
if ae, ok := e.(sql.AggregationExpression); ok {
187-
r = append(r, ae)
188-
} else {
189-
r = append(r, expression.NewFirst(e))
190-
}
186+
r = append(r, exprToAggregateExpr(e))
191187
}
192188

193189
return r
194190
}
191+
192+
func exprToAggregateExpr(e sql.Expression) sql.AggregationExpression {
193+
switch v := e.(type) {
194+
case sql.AggregationExpression:
195+
return v
196+
case *expression.Alias:
197+
return exprToAggregateExpr(v.Child)
198+
default:
199+
return expression.NewFirst(e)
200+
}
201+
}

0 commit comments

Comments
 (0)