From c5c04f8934a18d459ad82ffbc6e6758a11050b3c Mon Sep 17 00:00:00 2001 From: "Santiago M. Mola" Date: Fri, 13 Jan 2017 12:16:52 +0100 Subject: [PATCH] sql: implement database/sql/driver API. --- engine.go | 179 +++++++++++++++++++++++++++++++++++++++++++++++- engine_test.go | 78 +++++++++++---------- example_test.go | 66 ++++++++++++++++++ 3 files changed, 284 insertions(+), 39 deletions(-) create mode 100644 example_test.go diff --git a/engine.go b/engine.go index f6aaebb38..f776a93ff 100644 --- a/engine.go +++ b/engine.go @@ -1,17 +1,55 @@ package gitql import ( + gosql "database/sql" + "database/sql/driver" + "errors" + "fmt" + "github.com/gitql/gitql/sql" "github.com/gitql/gitql/sql/analyzer" "github.com/gitql/gitql/sql/expression" "github.com/gitql/gitql/sql/parse" ) +var ( + ErrNotSupported = errors.New("feature not supported yet") +) + +const ( + DriverName = "gitql" +) + +func init() { + gosql.Register(DriverName, defaultDriver) +} + +type drv struct{} + +var defaultDriver = &drv{} + +func (d *drv) Open(name string) (driver.Conn, error) { + if name != "" { + return nil, fmt.Errorf("data source not found: %s", name) + } + + e := DefaultEngine + return &session{Engine: e}, nil +} + +// DefaultEngine is the default Engine instance, used when opening a connection +// to gitql:// when using database/sql. +var DefaultEngine = New() + +// Engine is a SQL engine. +// It implements the standard database/sql/driver/Driver interface, so it can +// be registered as a database/sql driver. type Engine struct { Catalog *sql.Catalog Analyzer *analyzer.Analyzer } +// New creates a new Engine. func New() *Engine { c := sql.NewCatalog() err := expression.RegisterDefaults(c) @@ -23,11 +61,15 @@ func New() *Engine { return &Engine{c, a} } -func (e *Engine) AddDatabase(db sql.Database) { - e.Catalog.Databases = append(e.Catalog.Databases, db) - e.Analyzer.CurrentDatabase = db.Name() +// Open creates a new session for the engine and returns +// it as a driver.Conn. +// +// Name parameter is ignored. +func (e *Engine) Open(name string) (driver.Conn, error) { + return &session{Engine: e}, nil } +// Query executes a query without attaching to any session. func (e *Engine) Query(query string) (sql.Schema, sql.RowIter, error) { parsed, err := parse.Parse(query) if err != nil { @@ -46,3 +88,134 @@ func (e *Engine) Query(query string) (sql.Schema, sql.RowIter, error) { return analyzed.Schema(), iter, nil } + +func (e *Engine) AddDatabase(db sql.Database) { + e.Catalog.Databases = append(e.Catalog.Databases, db) + e.Analyzer.CurrentDatabase = db.Name() +} + +// Session represents a SQL session. +// It implements the standard database/sql/driver/Conn interface. +type session struct { + *Engine + closed bool + //TODO: Current database +} + +// Prepare returns a prepared statement, bound to this connection. +// Placeholders are not supported yet. +func (s *session) Prepare(query string) (driver.Stmt, error) { + if err := s.checkOpen(); err != nil { + return nil, err + } + + return &stmt{session: s, query: query}, nil +} + +// Close closes the session. +func (s *session) Close() error { + if err := s.checkOpen(); err != nil { + return err + } + + s.closed = true + return nil +} + +// Begin starts and returns a new transaction. +func (s *session) Begin() (driver.Tx, error) { + return nil, fmt.Errorf("transactions not supported") +} + +func (s *session) checkOpen() error { + if s.closed { + return driver.ErrBadConn + } + + return nil +} + +type stmt struct { + *session + query string + closed bool +} + +// Close closes the statement. +func (s *stmt) Close() error { + if err := s.checkOpen(); err != nil { + return err + } + + s.closed = true + return nil +} + +// NumInput returns the number of placeholder parameters. +// Always returns -1 since placeholders are not supported yet. +func (s *stmt) NumInput() int { + return -1 +} + +// Exec executes a query that doesn't return rows, such as an INSERT or UPDATE. +func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { + return nil, ErrNotSupported +} + +// Query executes a query that may return rows, such as a SELECT. +func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { + if len(args) > 0 { + return nil, ErrNotSupported + } + + schema, iter, err := s.session.Engine.Query(s.query) + if err != nil { + return nil, err + } + + return &rows{schema: schema, iter: iter}, nil +} + +func (s *stmt) checkOpen() error { + if s.closed { + return driver.ErrBadConn + } + + return nil +} + +type rows struct { + schema sql.Schema + iter sql.RowIter +} + +// Columns returns the names of the columns. +func (rs *rows) Columns() []string { + c := make([]string, len(rs.schema)) + for i := 0; i < len(rs.schema); i++ { + c[i] = rs.schema[i].Name + } + + return c +} + +// Close closes the rows iterator. +func (rs *rows) Close() error { + return rs.iter.Close() +} + +// Next populates the given array with the next row values. +// Returns io.EOF when there are no more values. +func (rs *rows) Next(dest []driver.Value) error { + r, err := rs.iter.Next() + if err != nil { + return err + } + + for i := range dest { + f := rs.schema[i] + dest[i] = f.Type.Native(r[i]) + } + + return nil +} diff --git a/engine_test.go b/engine_test.go index 71390abe7..3129a97be 100644 --- a/engine_test.go +++ b/engine_test.go @@ -1,7 +1,7 @@ package gitql_test import ( - "io" + gosql "database/sql" "testing" "github.com/gitql/gitql" @@ -11,77 +11,83 @@ import ( "github.com/stretchr/testify/require" ) +const ( + driverName = "engine_tests" +) + func TestEngine_Query(t *testing.T) { e := newEngine(t) + gosql.Register(driverName, e) testQuery(t, e, "SELECT i FROM mytable;", - []sql.Row{ - sql.NewRow(int64(1)), - sql.NewRow(int64(2)), - sql.NewRow(int64(3)), - }, + [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}}, ) testQuery(t, e, "SELECT i FROM mytable WHERE i = 2;", - []sql.Row{ - sql.NewRow(int64(2)), - }, + [][]interface{}{{int64(2)}}, ) testQuery(t, e, "SELECT i FROM mytable ORDER BY i DESC;", - []sql.Row{ - sql.NewRow(int64(3)), - sql.NewRow(int64(2)), - sql.NewRow(int64(1)), - }, + [][]interface{}{{int64(3)}, {int64(2)}, {int64(1)}}, ) testQuery(t, e, "SELECT i FROM mytable WHERE s = 'a' ORDER BY i DESC;", - []sql.Row{ - sql.NewRow(int64(1)), - }, + [][]interface{}{{int64(1)}}, ) testQuery(t, e, "SELECT i FROM mytable WHERE s = 'a' ORDER BY i DESC LIMIT 1;", - []sql.Row{ - sql.NewRow(int64(1)), - }, + [][]interface{}{{int64(1)}}, ) testQuery(t, e, "SELECT COUNT(*) FROM mytable;", - []sql.Row{ - sql.NewRow(int32(3)), - }, + [][]interface{}{{int64(3)}}, ) } -func testQuery(t *testing.T, e *gitql.Engine, q string, r []sql.Row) { +func testQuery(t *testing.T, e *gitql.Engine, q string, r [][]interface{}) { assert := require.New(t) - schema, iter, err := e.Query(q) - assert.Nil(err) - assert.NotNil(iter) - assert.NotNil(schema) + db, err := gosql.Open(driverName, "") + assert.NoError(err) + defer func() { assert.NoError(db.Close()) }() + + res, err := db.Query(q) + assert.NoError(err) + defer func() { assert.NoError(res.Close()) }() - results := []sql.Row{} + cols, err := res.Columns() + assert.NoError(err) + assert.Equal(len(r[0]), len(cols)) + + i := 0 for { - el, err := iter.Next() - if err == io.EOF { + if !res.Next() { break } - if err != nil { - assert.Fail("returned err distinct of io.EOF: %q", err) + + expectedRow := r[i] + i++ + + row := make([]interface{}, len(expectedRow)) + for i := range row { + i64 := int64(0) + row[i] = &i64 + } + + assert.NoError(res.Scan(row...)) + for i := range row { + row[i] = *(row[i].(*int64)) } - results = append(results, el) + + assert.Equal(expectedRow, row) } - assert.Len(results, len(r)) - assert.Equal(results, r) + assert.Equal(len(r), i) } func newEngine(t *testing.T) *gitql.Engine { diff --git a/example_test.go b/example_test.go new file mode 100644 index 000000000..deaa7bf9c --- /dev/null +++ b/example_test.go @@ -0,0 +1,66 @@ +package gitql_test + +import ( + "database/sql" + "fmt" + + "github.com/gitql/gitql" + "github.com/gitql/gitql/mem" + gitqlsql "github.com/gitql/gitql/sql" +) + +func Example() { + // Create a test memory database and register it to the default engine. + gitql.DefaultEngine.AddDatabase(createTestDatabase()) + + // Open a sql connection with the default engine. + conn, err := sql.Open(gitql.DriverName, "") + checkIfError(err) + + // Prepare a query. + stmt, err := conn.Prepare(`SELECT name, count(*) FROM mytable + WHERE name = 'John Doe' + GROUP BY name`) + checkIfError(err) + + // Get result rows. + rows, err := stmt.Query() + checkIfError(err) + + // Iterate results and print them. + for { + if !rows.Next() { + break + } + + name := "" + count := int64(0) + err := rows.Scan(&name, &count) + checkIfError(err) + + fmt.Println(name, count) + } + checkIfError(rows.Err()) + + // Output: John Doe 2 +} + +func checkIfError(err error) { + if err != nil { + panic(err) + } +} + +func createTestDatabase() *mem.Database { + db := mem.NewDatabase("test") + table := mem.NewTable("mytable", gitqlsql.Schema{ + gitqlsql.Column{Name: "name", Type: gitqlsql.String}, + gitqlsql.Column{Name: "email", Type: gitqlsql.String}, + }) + db.AddTable("mytable", table) + table.Insert(gitqlsql.NewRow("John Doe", "john@doe.com")) + table.Insert(gitqlsql.NewRow("John Doe", "johnalt@doe.com")) + table.Insert(gitqlsql.NewRow("Jane Doe", "jane@doe.com")) + table.Insert(gitqlsql.NewRow("Evil Bob", "evilbob@gmail.com")) + return db +}