Skip to content

Commit c613042

Browse files
authored
sql: implement database/sql/driver API, fixes src-d#87 (src-d#108)
1 parent a9fa622 commit c613042

File tree

3 files changed

+284
-39
lines changed

3 files changed

+284
-39
lines changed

engine.go

+176-3
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,55 @@
11
package gitql
22

33
import (
4+
gosql "database/sql"
5+
"database/sql/driver"
6+
"errors"
7+
"fmt"
8+
49
"github.com/gitql/gitql/sql"
510
"github.com/gitql/gitql/sql/analyzer"
611
"github.com/gitql/gitql/sql/expression"
712
"github.com/gitql/gitql/sql/parse"
813
)
914

15+
var (
16+
ErrNotSupported = errors.New("feature not supported yet")
17+
)
18+
19+
const (
20+
DriverName = "gitql"
21+
)
22+
23+
func init() {
24+
gosql.Register(DriverName, defaultDriver)
25+
}
26+
27+
type drv struct{}
28+
29+
var defaultDriver = &drv{}
30+
31+
func (d *drv) Open(name string) (driver.Conn, error) {
32+
if name != "" {
33+
return nil, fmt.Errorf("data source not found: %s", name)
34+
}
35+
36+
e := DefaultEngine
37+
return &session{Engine: e}, nil
38+
}
39+
40+
// DefaultEngine is the default Engine instance, used when opening a connection
41+
// to gitql:// when using database/sql.
42+
var DefaultEngine = New()
43+
44+
// Engine is a SQL engine.
45+
// It implements the standard database/sql/driver/Driver interface, so it can
46+
// be registered as a database/sql driver.
1047
type Engine struct {
1148
Catalog *sql.Catalog
1249
Analyzer *analyzer.Analyzer
1350
}
1451

52+
// New creates a new Engine.
1553
func New() *Engine {
1654
c := sql.NewCatalog()
1755
err := expression.RegisterDefaults(c)
@@ -23,11 +61,15 @@ func New() *Engine {
2361
return &Engine{c, a}
2462
}
2563

26-
func (e *Engine) AddDatabase(db sql.Database) {
27-
e.Catalog.Databases = append(e.Catalog.Databases, db)
28-
e.Analyzer.CurrentDatabase = db.Name()
64+
// Open creates a new session for the engine and returns
65+
// it as a driver.Conn.
66+
//
67+
// Name parameter is ignored.
68+
func (e *Engine) Open(name string) (driver.Conn, error) {
69+
return &session{Engine: e}, nil
2970
}
3071

72+
// Query executes a query without attaching to any session.
3173
func (e *Engine) Query(query string) (sql.Schema, sql.RowIter, error) {
3274
parsed, err := parse.Parse(query)
3375
if err != nil {
@@ -46,3 +88,134 @@ func (e *Engine) Query(query string) (sql.Schema, sql.RowIter, error) {
4688

4789
return analyzed.Schema(), iter, nil
4890
}
91+
92+
func (e *Engine) AddDatabase(db sql.Database) {
93+
e.Catalog.Databases = append(e.Catalog.Databases, db)
94+
e.Analyzer.CurrentDatabase = db.Name()
95+
}
96+
97+
// Session represents a SQL session.
98+
// It implements the standard database/sql/driver/Conn interface.
99+
type session struct {
100+
*Engine
101+
closed bool
102+
//TODO: Current database
103+
}
104+
105+
// Prepare returns a prepared statement, bound to this connection.
106+
// Placeholders are not supported yet.
107+
func (s *session) Prepare(query string) (driver.Stmt, error) {
108+
if err := s.checkOpen(); err != nil {
109+
return nil, err
110+
}
111+
112+
return &stmt{session: s, query: query}, nil
113+
}
114+
115+
// Close closes the session.
116+
func (s *session) Close() error {
117+
if err := s.checkOpen(); err != nil {
118+
return err
119+
}
120+
121+
s.closed = true
122+
return nil
123+
}
124+
125+
// Begin starts and returns a new transaction.
126+
func (s *session) Begin() (driver.Tx, error) {
127+
return nil, fmt.Errorf("transactions not supported")
128+
}
129+
130+
func (s *session) checkOpen() error {
131+
if s.closed {
132+
return driver.ErrBadConn
133+
}
134+
135+
return nil
136+
}
137+
138+
type stmt struct {
139+
*session
140+
query string
141+
closed bool
142+
}
143+
144+
// Close closes the statement.
145+
func (s *stmt) Close() error {
146+
if err := s.checkOpen(); err != nil {
147+
return err
148+
}
149+
150+
s.closed = true
151+
return nil
152+
}
153+
154+
// NumInput returns the number of placeholder parameters.
155+
// Always returns -1 since placeholders are not supported yet.
156+
func (s *stmt) NumInput() int {
157+
return -1
158+
}
159+
160+
// Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
161+
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
162+
return nil, ErrNotSupported
163+
}
164+
165+
// Query executes a query that may return rows, such as a SELECT.
166+
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
167+
if len(args) > 0 {
168+
return nil, ErrNotSupported
169+
}
170+
171+
schema, iter, err := s.session.Engine.Query(s.query)
172+
if err != nil {
173+
return nil, err
174+
}
175+
176+
return &rows{schema: schema, iter: iter}, nil
177+
}
178+
179+
func (s *stmt) checkOpen() error {
180+
if s.closed {
181+
return driver.ErrBadConn
182+
}
183+
184+
return nil
185+
}
186+
187+
type rows struct {
188+
schema sql.Schema
189+
iter sql.RowIter
190+
}
191+
192+
// Columns returns the names of the columns.
193+
func (rs *rows) Columns() []string {
194+
c := make([]string, len(rs.schema))
195+
for i := 0; i < len(rs.schema); i++ {
196+
c[i] = rs.schema[i].Name
197+
}
198+
199+
return c
200+
}
201+
202+
// Close closes the rows iterator.
203+
func (rs *rows) Close() error {
204+
return rs.iter.Close()
205+
}
206+
207+
// Next populates the given array with the next row values.
208+
// Returns io.EOF when there are no more values.
209+
func (rs *rows) Next(dest []driver.Value) error {
210+
r, err := rs.iter.Next()
211+
if err != nil {
212+
return err
213+
}
214+
215+
for i := range dest {
216+
f := rs.schema[i]
217+
dest[i] = f.Type.Native(r[i])
218+
}
219+
220+
return nil
221+
}

engine_test.go

+42-36
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package gitql_test
22

33
import (
4-
"io"
4+
gosql "database/sql"
55
"testing"
66

77
"github.com/gitql/gitql"
@@ -11,77 +11,83 @@ import (
1111
"github.com/stretchr/testify/require"
1212
)
1313

14+
const (
15+
driverName = "engine_tests"
16+
)
17+
1418
func TestEngine_Query(t *testing.T) {
1519
e := newEngine(t)
20+
gosql.Register(driverName, e)
1621
testQuery(t, e,
1722
"SELECT i FROM mytable;",
18-
[]sql.Row{
19-
sql.NewRow(int64(1)),
20-
sql.NewRow(int64(2)),
21-
sql.NewRow(int64(3)),
22-
},
23+
[][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}},
2324
)
2425

2526
testQuery(t, e,
2627
"SELECT i FROM mytable WHERE i = 2;",
27-
[]sql.Row{
28-
sql.NewRow(int64(2)),
29-
},
28+
[][]interface{}{{int64(2)}},
3029
)
3130

3231
testQuery(t, e,
3332
"SELECT i FROM mytable ORDER BY i DESC;",
34-
[]sql.Row{
35-
sql.NewRow(int64(3)),
36-
sql.NewRow(int64(2)),
37-
sql.NewRow(int64(1)),
38-
},
33+
[][]interface{}{{int64(3)}, {int64(2)}, {int64(1)}},
3934
)
4035

4136
testQuery(t, e,
4237
"SELECT i FROM mytable WHERE s = 'a' ORDER BY i DESC;",
43-
[]sql.Row{
44-
sql.NewRow(int64(1)),
45-
},
38+
[][]interface{}{{int64(1)}},
4639
)
4740

4841
testQuery(t, e,
4942
"SELECT i FROM mytable WHERE s = 'a' ORDER BY i DESC LIMIT 1;",
50-
[]sql.Row{
51-
sql.NewRow(int64(1)),
52-
},
43+
[][]interface{}{{int64(1)}},
5344
)
5445

5546
testQuery(t, e,
5647
"SELECT COUNT(*) FROM mytable;",
57-
[]sql.Row{
58-
sql.NewRow(int32(3)),
59-
},
48+
[][]interface{}{{int64(3)}},
6049
)
6150
}
6251

63-
func testQuery(t *testing.T, e *gitql.Engine, q string, r []sql.Row) {
52+
func testQuery(t *testing.T, e *gitql.Engine, q string, r [][]interface{}) {
6453
assert := require.New(t)
6554

66-
schema, iter, err := e.Query(q)
67-
assert.Nil(err)
68-
assert.NotNil(iter)
69-
assert.NotNil(schema)
55+
db, err := gosql.Open(driverName, "")
56+
assert.NoError(err)
57+
defer func() { assert.NoError(db.Close()) }()
58+
59+
res, err := db.Query(q)
60+
assert.NoError(err)
61+
defer func() { assert.NoError(res.Close()) }()
7062

71-
results := []sql.Row{}
63+
cols, err := res.Columns()
64+
assert.NoError(err)
65+
assert.Equal(len(r[0]), len(cols))
66+
67+
i := 0
7268
for {
73-
el, err := iter.Next()
74-
if err == io.EOF {
69+
if !res.Next() {
7570
break
7671
}
77-
if err != nil {
78-
assert.Fail("returned err distinct of io.EOF: %q", err)
72+
73+
expectedRow := r[i]
74+
i++
75+
76+
row := make([]interface{}, len(expectedRow))
77+
for i := range row {
78+
i64 := int64(0)
79+
row[i] = &i64
80+
}
81+
82+
assert.NoError(res.Scan(row...))
83+
for i := range row {
84+
row[i] = *(row[i].(*int64))
7985
}
80-
results = append(results, el)
86+
87+
assert.Equal(expectedRow, row)
8188
}
8289

83-
assert.Len(results, len(r))
84-
assert.Equal(results, r)
90+
assert.Equal(len(r), i)
8591
}
8692

8793
func newEngine(t *testing.T) *gitql.Engine {

0 commit comments

Comments
 (0)