Skip to content

Commit 14cd44c

Browse files
authored
sql: add GROUP BY support. Closes #52. (#86)
* sql: add AggregationExpression interface. * sql: add function registry to Catalog. * sql/expression: add Count and First implementations. * sql/plan: add GroupBy node.
1 parent a2cbdbc commit 14cd44c

20 files changed

+957
-43
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ gitql supports a subset of the SQL standard, currently including:
5757
* `WHERE`
5858
* `ORDER BY` (with `ASC` and `DESC`)
5959
* `LIMIT`
60+
* `GROUP BY` (with `COUNT` and `FIRST`)
6061
* `SHOW TABLES`
6162
* `DESCRIBE TABLE`
6263

engine.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"github.com/gitql/gitql/sql"
55
"github.com/gitql/gitql/sql/analyzer"
66
"github.com/gitql/gitql/sql/parse"
7+
"github.com/gitql/gitql/sql/expression"
78
)
89

910
type Engine struct {
@@ -12,7 +13,12 @@ type Engine struct {
1213
}
1314

1415
func New() *Engine {
15-
c := &sql.Catalog{}
16+
c := sql.NewCatalog()
17+
err := expression.RegisterDefaults(c)
18+
if err != nil {
19+
panic(err)
20+
}
21+
1622
a := analyzer.New(c)
1723
return &Engine{c, a}
1824
}

engine_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ func TestEngine_Query(t *testing.T) {
5151
sql.NewMemoryRow(int64(1)),
5252
},
5353
)
54+
55+
testQuery(t, e,
56+
"SELECT COUNT(*) FROM mytable;",
57+
[]sql.Row{
58+
sql.NewMemoryRow(int32(3)),
59+
},
60+
)
5461
}
5562

5663
func testQuery(t *testing.T, e *gitql.Engine, q string, r []sql.Row) {

sql/analyzer/rules.go

+27
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ var DefaultRules = []Rule{
1111
{"resolve_columns", resolveColumns},
1212
{"resolve_database", resolveDatabase},
1313
{"resolve_star", resolveStar},
14+
{"resolve_functions", resolveFunctions},
1415
}
1516

1617
func resolveDatabase(a *Analyzer, n sql.Node) sql.Node {
@@ -102,3 +103,29 @@ func resolveColumns(a *Analyzer, n sql.Node) sql.Node {
102103
return gf
103104
})
104105
}
106+
107+
func resolveFunctions(a *Analyzer, n sql.Node) sql.Node {
108+
if n.Resolved() {
109+
return n
110+
}
111+
112+
return n.TransformExpressionsUp(func(e sql.Expression) sql.Expression {
113+
uf, ok := e.(*expression.UnresolvedFunction)
114+
if !ok {
115+
return e
116+
}
117+
118+
n := uf.Name()
119+
f, err := a.Catalog.Function(n)
120+
if err != nil {
121+
return e
122+
}
123+
124+
rf, err := f.Build(uf.Children...)
125+
if err != nil {
126+
return e
127+
}
128+
129+
return rf
130+
})
131+
}

sql/catalog.go

+107
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,20 @@
11
package sql
22

33
import (
4+
"errors"
45
"fmt"
6+
"reflect"
57
)
68

79
type Catalog struct {
810
Databases []Database
11+
Functions map[string]*FunctionEntry
12+
}
13+
14+
func NewCatalog() *Catalog {
15+
return &Catalog{
16+
Functions: map[string]*FunctionEntry{},
17+
}
918
}
1019

1120
func (c Catalog) Database(name string) (Database, error) {
@@ -32,3 +41,101 @@ func (c Catalog) Table(dbName string, tableName string) (Table, error) {
3241

3342
return table, nil
3443
}
44+
45+
func (c Catalog) RegisterFunction(name string, f interface{}) error {
46+
e, err := inspectFunction(f)
47+
if err != nil {
48+
return err
49+
}
50+
51+
c.Functions[name] = e
52+
return nil
53+
}
54+
55+
func (c Catalog) Function(name string) (*FunctionEntry, error) {
56+
e, ok := c.Functions[name]
57+
if !ok {
58+
return nil, fmt.Errorf("function not found: %s", name)
59+
}
60+
61+
return e, nil
62+
}
63+
64+
type FunctionEntry struct {
65+
v reflect.Value
66+
}
67+
68+
func (e *FunctionEntry) Build(args ...Expression) (Expression, error) {
69+
t := e.v.Type()
70+
if !t.IsVariadic() && len(args) != t.NumIn() {
71+
return nil, fmt.Errorf("expected %d args, got %d",
72+
t.NumIn(), len(args))
73+
}
74+
75+
if t.IsVariadic() && len(args) < t.NumIn()-1 {
76+
return nil, fmt.Errorf("expected at least %d args, got %d",
77+
t.NumIn(), len(args))
78+
}
79+
80+
var in []reflect.Value
81+
for _, arg := range args {
82+
in = append(in, reflect.ValueOf(arg))
83+
}
84+
85+
out := e.v.Call(in)
86+
if len(out) != 1 {
87+
return nil, fmt.Errorf("expected 1 return value, got %d: ", len(out))
88+
}
89+
90+
expr, ok := out[0].Interface().(Expression)
91+
if !ok {
92+
return nil, errors.New("return value doesn't implement Expression")
93+
}
94+
95+
return expr, nil
96+
}
97+
98+
var (
99+
expressionType = buildExpressionType()
100+
expressionSliceType = buildExpressionSliceType()
101+
)
102+
103+
func buildExpressionType() reflect.Type {
104+
var v Expression
105+
return reflect.ValueOf(&v).Elem().Type()
106+
}
107+
108+
func buildExpressionSliceType() reflect.Type {
109+
var v []Expression
110+
return reflect.ValueOf(&v).Elem().Type()
111+
}
112+
113+
func inspectFunction(f interface{}) (*FunctionEntry, error) {
114+
v := reflect.ValueOf(f)
115+
t := v.Type()
116+
if t.Kind() != reflect.Func {
117+
return nil, fmt.Errorf("expected function, got: %s", t.Kind())
118+
}
119+
120+
if t.NumOut() != 1 {
121+
return nil, errors.New("function builders must return a single Expression")
122+
}
123+
124+
out := t.Out(0)
125+
if !out.Implements(expressionType) {
126+
return nil, fmt.Errorf("return value doesn't implement Expression: %s", out)
127+
}
128+
129+
for i := 0; i < t.NumIn(); i++ {
130+
in := t.In(i)
131+
if i == t.NumIn()-1 && t.IsVariadic() && in == expressionSliceType {
132+
continue
133+
}
134+
135+
if in != expressionType {
136+
return nil, fmt.Errorf("input argument %d is not a Expression", i)
137+
}
138+
}
139+
140+
return &FunctionEntry{v}, nil
141+
}

sql/catalog_test.go

+144-2
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ import (
66
"github.com/gitql/gitql/mem"
77
"github.com/gitql/gitql/sql"
88

9+
"github.com/gitql/gitql/sql/expression"
910
"github.com/stretchr/testify/assert"
1011
)
1112

1213
func TestCatalog_Database(t *testing.T) {
1314
assert := assert.New(t)
1415

15-
c := sql.Catalog{}
16+
c := sql.NewCatalog()
1617
db, err := c.Database("foo")
1718
assert.EqualError(err, "database not found: foo")
1819
assert.Nil(db)
@@ -28,7 +29,7 @@ func TestCatalog_Database(t *testing.T) {
2829
func TestCatalog_Table(t *testing.T) {
2930
assert := assert.New(t)
3031

31-
c := sql.Catalog{}
32+
c := sql.NewCatalog()
3233

3334
table, err := c.Table("foo", "bar")
3435
assert.EqualError(err, "database not found: foo")
@@ -48,3 +49,144 @@ func TestCatalog_Table(t *testing.T) {
4849
assert.NoError(err)
4950
assert.Equal(mytable, table)
5051
}
52+
53+
func TestCatalog_RegisterFunction_NoArgs(t *testing.T) {
54+
assert := assert.New(t)
55+
56+
c := sql.NewCatalog()
57+
name := "func"
58+
var expected sql.Expression = expression.NewStar()
59+
err := c.RegisterFunction(name, func() sql.Expression {
60+
return expected
61+
})
62+
assert.Nil(err)
63+
64+
f, err := c.Function(name)
65+
assert.Nil(err)
66+
67+
e, err := f.Build()
68+
assert.Nil(err)
69+
assert.Equal(expected, e)
70+
71+
e, err = f.Build(expression.NewStar())
72+
assert.NotNil(err)
73+
assert.Nil(e)
74+
75+
e, err = f.Build(expression.NewStar(), expression.NewStar())
76+
assert.NotNil(err)
77+
assert.Nil(e)
78+
}
79+
80+
func TestCatalog_RegisterFunction_OneArg(t *testing.T) {
81+
assert := assert.New(t)
82+
83+
c := sql.NewCatalog()
84+
name := "func"
85+
var expected sql.Expression = expression.NewStar()
86+
err := c.RegisterFunction(name, func(sql.Expression) sql.Expression {
87+
return expected
88+
})
89+
assert.Nil(err)
90+
91+
f, err := c.Function(name)
92+
assert.Nil(err)
93+
94+
e, err := f.Build()
95+
assert.NotNil(err)
96+
assert.Nil(e)
97+
98+
e, err = f.Build(expression.NewStar())
99+
assert.Nil(err)
100+
assert.Equal(expected, e)
101+
102+
e, err = f.Build(expression.NewStar(), expression.NewStar())
103+
assert.NotNil(err)
104+
assert.Nil(e)
105+
}
106+
107+
func TestCatalog_RegisterFunction_Variadic(t *testing.T) {
108+
assert := assert.New(t)
109+
110+
c := sql.NewCatalog()
111+
name := "func"
112+
var expected sql.Expression = expression.NewStar()
113+
err := c.RegisterFunction(name, func(...sql.Expression) sql.Expression {
114+
return expected
115+
})
116+
assert.Nil(err)
117+
118+
f, err := c.Function(name)
119+
assert.Nil(err)
120+
121+
e, err := f.Build()
122+
assert.Nil(err)
123+
assert.Equal(expected, e)
124+
125+
e, err = f.Build(expression.NewStar())
126+
assert.Nil(err)
127+
assert.Equal(expected, e)
128+
129+
e, err = f.Build(expression.NewStar(), expression.NewStar())
130+
assert.Nil(err)
131+
assert.Equal(expected, e)
132+
}
133+
134+
func TestCatalog_RegisterFunction_OneAndVariadic(t *testing.T) {
135+
assert := assert.New(t)
136+
137+
c := sql.NewCatalog()
138+
name := "func"
139+
var expected sql.Expression = expression.NewStar()
140+
err := c.RegisterFunction(name, func(sql.Expression, ...sql.Expression) sql.Expression {
141+
return expected
142+
})
143+
assert.Nil(err)
144+
145+
f, err := c.Function(name)
146+
assert.Nil(err)
147+
148+
e, err := f.Build()
149+
assert.NotNil(err)
150+
assert.Nil(e)
151+
152+
e, err = f.Build(expression.NewStar())
153+
assert.Nil(err)
154+
assert.Equal(expected, e)
155+
156+
e, err = f.Build(expression.NewStar(), expression.NewStar())
157+
assert.Nil(err)
158+
assert.Equal(expected, e)
159+
}
160+
161+
func TestCatalog_RegisterFunction_Invalid(t *testing.T) {
162+
assert := assert.New(t)
163+
164+
c := sql.NewCatalog()
165+
name := "func"
166+
err := c.RegisterFunction(name, func(sql.Table) sql.Expression {
167+
return nil
168+
})
169+
assert.NotNil(err)
170+
171+
err = c.RegisterFunction(name, func(sql.Expression) sql.Table {
172+
return nil
173+
})
174+
assert.NotNil(err)
175+
176+
err = c.RegisterFunction(name, func(sql.Expression) (sql.Table, error) {
177+
return nil, nil
178+
})
179+
assert.NotNil(err)
180+
181+
err = c.RegisterFunction(name, 1)
182+
assert.NotNil(err)
183+
}
184+
185+
func TestCatalog_Function_NotExists(t *testing.T) {
186+
assert := assert.New(t)
187+
188+
c := sql.NewCatalog()
189+
f, err := c.Function("func")
190+
assert.NotNil(err)
191+
assert.Nil(f)
192+
}

0 commit comments

Comments
 (0)