Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit b829206

Browse files
authored
Merge pull request #633 from erizocosmico/fix/reorder-error
allow all expressions in grouping, resolve orderby expressions
2 parents 08e98ce + 52476d6 commit b829206

File tree

4 files changed

+44
-21
lines changed

4 files changed

+44
-21
lines changed

Diff for: engine_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,14 @@ var queries = []struct {
868868
"ROLLBACK",
869869
[]sql.Row{},
870870
},
871+
{
872+
"SELECT substring(s, 1, 1) FROM mytable ORDER BY substring(s, 1, 1)",
873+
[]sql.Row{{"f"}, {"s"}, {"t"}},
874+
},
875+
{
876+
"SELECT substring(s, 1, 1), count(*) FROM mytable GROUP BY substring(s, 1, 1)",
877+
[]sql.Row{{"f", int32(1)}, {"s", int32(1)}, {"t", int32(1)}},
878+
},
871879
}
872880

873881
func TestQueries(t *testing.T) {

Diff for: sql/analyzer/resolve_orderby.go

+20-8
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,14 @@ func resolveOrderBy(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node, error)
3535
var colsFromChild []string
3636
var missingCols []string
3737
for _, f := range sort.SortFields {
38-
n, ok := f.Column.(sql.Nameable)
39-
if !ok {
40-
continue
41-
}
38+
ns := findExprNameables(f.Column)
4239

43-
if stringContains(childNewCols, n.Name()) {
44-
colsFromChild = append(colsFromChild, n.Name())
45-
} else if !stringContains(schemaCols, n.Name()) {
46-
missingCols = append(missingCols, n.Name())
40+
for _, n := range ns {
41+
if stringContains(childNewCols, n.Name()) {
42+
colsFromChild = append(colsFromChild, n.Name())
43+
} else if !stringContains(schemaCols, n.Name()) {
44+
missingCols = append(missingCols, n.Name())
45+
}
4746
}
4847
}
4948

@@ -221,3 +220,16 @@ func resolveOrderByLiterals(ctx *sql.Context, a *Analyzer, n sql.Node) (sql.Node
221220
return plan.NewSort(fields, sort.Child), nil
222221
})
223222
}
223+
224+
func findExprNameables(e sql.Expression) []sql.Nameable {
225+
var result []sql.Nameable
226+
expression.Inspect(e, func(e sql.Expression) bool {
227+
n, ok := e.(sql.Nameable)
228+
if ok {
229+
result = append(result, n)
230+
return false
231+
}
232+
return true
233+
})
234+
return result
235+
}

Diff for: sql/plan/group_by.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -353,15 +353,13 @@ func updateBuffer(
353353
return n.Update(ctx, buffers[idx], row)
354354
case *expression.Alias:
355355
return updateBuffer(ctx, buffers, idx, n.Child, row)
356-
case *expression.GetField:
356+
default:
357357
val, err := expr.Eval(ctx, row)
358358
if err != nil {
359359
return err
360360
}
361361
buffers[idx] = sql.NewRow(val)
362362
return nil
363-
default:
364-
return ErrGroupBy.New(n.String())
365363
}
366364
}
367365

@@ -393,12 +391,10 @@ func evalBuffer(
393391
return n.Eval(ctx, buffer)
394392
case *expression.Alias:
395393
return evalBuffer(ctx, n.Child, buffer)
396-
case *expression.GetField:
394+
default:
397395
if len(buffer) > 0 {
398396
return buffer[0], nil
399397
}
400398
return nil, nil
401-
default:
402-
return nil, ErrGroupBy.New(n.String())
403399
}
404400
}

Diff for: sql/plan/group_by_test.go

+14-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import (
1010
"gopkg.in/src-d/go-mysql-server.v0/sql/expression/function/aggregation"
1111
)
1212

13-
func TestGroupBy_Schema(t *testing.T) {
13+
func TestGroupBySchema(t *testing.T) {
1414
require := require.New(t)
1515

1616
child := mem.NewTable("test", nil)
@@ -25,7 +25,7 @@ func TestGroupBy_Schema(t *testing.T) {
2525
}, gb.Schema())
2626
}
2727

28-
func TestGroupBy_Resolved(t *testing.T) {
28+
func TestGroupByResolved(t *testing.T) {
2929
require := require.New(t)
3030

3131
child := mem.NewTable("test", nil)
@@ -42,7 +42,7 @@ func TestGroupBy_Resolved(t *testing.T) {
4242
require.False(gb.Resolved())
4343
}
4444

45-
func TestGroupBy_RowIter(t *testing.T) {
45+
func TestGroupByRowIter(t *testing.T) {
4646
require := require.New(t)
4747
ctx := sql.NewEmptyContext()
4848

@@ -96,7 +96,7 @@ func TestGroupBy_RowIter(t *testing.T) {
9696
require.Equal(sql.NewRow("col1_2", int64(4444)), rows[1])
9797
}
9898

99-
func TestGroupBy_EvalEmptyBuffer(t *testing.T) {
99+
func TestGroupByEvalEmptyBuffer(t *testing.T) {
100100
require := require.New(t)
101101
ctx := sql.NewEmptyContext()
102102

@@ -105,7 +105,7 @@ func TestGroupBy_EvalEmptyBuffer(t *testing.T) {
105105
require.Nil(r)
106106
}
107107

108-
func TestGroupBy_Error(t *testing.T) {
108+
func TestGroupByAggregationGrouping(t *testing.T) {
109109
require := require.New(t)
110110
ctx := sql.NewEmptyContext()
111111

@@ -140,8 +140,15 @@ func TestGroupBy_Error(t *testing.T) {
140140
NewResolvedTable(child),
141141
)
142142

143-
_, err := sql.NodeToRows(ctx, p)
144-
require.Error(err)
143+
rows, err := sql.NodeToRows(ctx, p)
144+
require.NoError(err)
145+
146+
expected := []sql.Row{
147+
{int32(3), false},
148+
{int32(2), false},
149+
}
150+
151+
require.Equal(expected, rows)
145152
}
146153

147154
func BenchmarkGroupBy(b *testing.B) {

0 commit comments

Comments
 (0)