Skip to content

sql: add GROUP BY support. Closes #52. #86

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 23, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ gitql supports a subset of the SQL standard, currently including:
* `WHERE`
* `ORDER BY` (with `ASC` and `DESC`)
* `LIMIT`
* `GROUP BY` (with `COUNT` and `FIRST`)
* `SHOW TABLES`
* `DESCRIBE TABLE`

Expand Down
8 changes: 7 additions & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"github.com/gitql/gitql/sql"
"github.com/gitql/gitql/sql/analyzer"
"github.com/gitql/gitql/sql/parse"
"github.com/gitql/gitql/sql/expression"
)

type Engine struct {
Expand All @@ -12,7 +13,12 @@ type Engine struct {
}

func New() *Engine {
c := &sql.Catalog{}
c := sql.NewCatalog()
err := expression.RegisterDefaults(c)
if err != nil {
panic(err)
}

a := analyzer.New(c)
return &Engine{c, a}
}
Expand Down
7 changes: 7 additions & 0 deletions engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ func TestEngine_Query(t *testing.T) {
sql.NewMemoryRow(int64(1)),
},
)

testQuery(t, e,
"SELECT COUNT(*) FROM mytable;",
[]sql.Row{
sql.NewMemoryRow(int32(3)),
},
)
}

func testQuery(t *testing.T, e *gitql.Engine, q string, r []sql.Row) {
Expand Down
27 changes: 27 additions & 0 deletions sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ var DefaultRules = []Rule{
{"resolve_columns", resolveColumns},
{"resolve_database", resolveDatabase},
{"resolve_star", resolveStar},
{"resolve_functions", resolveFunctions},
}

func resolveDatabase(a *Analyzer, n sql.Node) sql.Node {
Expand Down Expand Up @@ -102,3 +103,29 @@ func resolveColumns(a *Analyzer, n sql.Node) sql.Node {
return gf
})
}

func resolveFunctions(a *Analyzer, n sql.Node) sql.Node {
if n.Resolved() {
return n
}

return n.TransformExpressionsUp(func(e sql.Expression) sql.Expression {
uf, ok := e.(*expression.UnresolvedFunction)
if !ok {
return e
}

n := uf.Name()
f, err := a.Catalog.Function(n)
if err != nil {
return e
}

rf, err := f.Build(uf.Children...)
if err != nil {
return e
}

return rf
})
}
107 changes: 107 additions & 0 deletions sql/catalog.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
package sql

import (
"errors"
"fmt"
"reflect"
)

type Catalog struct {
Databases []Database
Functions map[string]*FunctionEntry
}

func NewCatalog() *Catalog {
return &Catalog{
Functions: map[string]*FunctionEntry{},
}
}

func (c Catalog) Database(name string) (Database, error) {
Expand All @@ -32,3 +41,101 @@ func (c Catalog) Table(dbName string, tableName string) (Table, error) {

return table, nil
}

func (c Catalog) RegisterFunction(name string, f interface{}) error {
e, err := inspectFunction(f)
if err != nil {
return err
}

c.Functions[name] = e
return nil
}

func (c Catalog) Function(name string) (*FunctionEntry, error) {
e, ok := c.Functions[name]
if !ok {
return nil, fmt.Errorf("function not found: %s", name)
}

return e, nil
}

type FunctionEntry struct {
v reflect.Value
}

func (e *FunctionEntry) Build(args ...Expression) (Expression, error) {
t := e.v.Type()
if !t.IsVariadic() && len(args) != t.NumIn() {
return nil, fmt.Errorf("expected %d args, got %d",
t.NumIn(), len(args))
}

if t.IsVariadic() && len(args) < t.NumIn()-1 {
return nil, fmt.Errorf("expected at least %d args, got %d",
t.NumIn(), len(args))
}

var in []reflect.Value
for _, arg := range args {
in = append(in, reflect.ValueOf(arg))
}

out := e.v.Call(in)
if len(out) != 1 {
return nil, fmt.Errorf("expected 1 return value, got %d: ", len(out))
}

expr, ok := out[0].Interface().(Expression)
if !ok {
return nil, errors.New("return value doesn't implement Expression")
}

return expr, nil
}

var (
expressionType = buildExpressionType()
expressionSliceType = buildExpressionSliceType()
)

func buildExpressionType() reflect.Type {
var v Expression
return reflect.ValueOf(&v).Elem().Type()
}

func buildExpressionSliceType() reflect.Type {
var v []Expression
return reflect.ValueOf(&v).Elem().Type()
}

func inspectFunction(f interface{}) (*FunctionEntry, error) {
v := reflect.ValueOf(f)
t := v.Type()
if t.Kind() != reflect.Func {
return nil, fmt.Errorf("expected function, got: %s", t.Kind())
}

if t.NumOut() != 1 {
return nil, errors.New("function builders must return a single Expression")
}

out := t.Out(0)
if !out.Implements(expressionType) {
return nil, fmt.Errorf("return value doesn't implement Expression: %s", out)
}

for i := 0; i < t.NumIn(); i++ {
in := t.In(i)
if i == t.NumIn()-1 && t.IsVariadic() && in == expressionSliceType {
continue
}

if in != expressionType {
return nil, fmt.Errorf("input argument %d is not a Expression", i)
}
}

return &FunctionEntry{v}, nil
}
146 changes: 144 additions & 2 deletions sql/catalog_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ import (
"github.com/gitql/gitql/mem"
"github.com/gitql/gitql/sql"

"github.com/gitql/gitql/sql/expression"
"github.com/stretchr/testify/assert"
)

func TestCatalog_Database(t *testing.T) {
assert := assert.New(t)

c := sql.Catalog{}
c := sql.NewCatalog()
db, err := c.Database("foo")
assert.EqualError(err, "database not found: foo")
assert.Nil(db)
Expand All @@ -28,7 +29,7 @@ func TestCatalog_Database(t *testing.T) {
func TestCatalog_Table(t *testing.T) {
assert := assert.New(t)

c := sql.Catalog{}
c := sql.NewCatalog()

table, err := c.Table("foo", "bar")
assert.EqualError(err, "database not found: foo")
Expand All @@ -48,3 +49,144 @@ func TestCatalog_Table(t *testing.T) {
assert.NoError(err)
assert.Equal(mytable, table)
}

func TestCatalog_RegisterFunction_NoArgs(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func() sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar())
assert.NotNil(err)
assert.Nil(e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.NotNil(err)
assert.Nil(e)
}

func TestCatalog_RegisterFunction_OneArg(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func(sql.Expression) sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.NotNil(err)
assert.Nil(e)

e, err = f.Build(expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.NotNil(err)
assert.Nil(e)
}

func TestCatalog_RegisterFunction_Variadic(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func(...sql.Expression) sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)
}

func TestCatalog_RegisterFunction_OneAndVariadic(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
var expected sql.Expression = expression.NewStar()
err := c.RegisterFunction(name, func(sql.Expression, ...sql.Expression) sql.Expression {
return expected
})
assert.Nil(err)

f, err := c.Function(name)
assert.Nil(err)

e, err := f.Build()
assert.NotNil(err)
assert.Nil(e)

e, err = f.Build(expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)

e, err = f.Build(expression.NewStar(), expression.NewStar())
assert.Nil(err)
assert.Equal(expected, e)
}

func TestCatalog_RegisterFunction_Invalid(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
name := "func"
err := c.RegisterFunction(name, func(sql.Table) sql.Expression {
return nil
})
assert.NotNil(err)

err = c.RegisterFunction(name, func(sql.Expression) sql.Table {
return nil
})
assert.NotNil(err)

err = c.RegisterFunction(name, func(sql.Expression) (sql.Table, error) {
return nil, nil
})
assert.NotNil(err)

err = c.RegisterFunction(name, 1)
assert.NotNil(err)
}

func TestCatalog_Function_NotExists(t *testing.T) {
assert := assert.New(t)

c := sql.NewCatalog()
f, err := c.Function("func")
assert.NotNil(err)
assert.Nil(f)
}
Loading