Skip to content

Commit 9db12c7

Browse files
authored
Merge pull request #33 from asdine/support-connbegintx
Add support for driver.ConnBeginTx
2 parents fc410b0 + 4803583 commit 9db12c7

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

sqlhooks.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ func (drv *Driver) Open(name string) (driver.Conn, error) {
4949
return conn, err
5050
}
5151

52+
// Drivers that don't implement driver.ConnBeginTx are not supported.
53+
if _, ok := conn.(driver.ConnBeginTx); !ok {
54+
return nil, errors.New("driver must implement driver.ConnBeginTx")
55+
}
56+
5257
wrapped := &Conn{conn, drv.hooks}
5358
if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) {
5459
return &ExecerQueryerContextWithSessionResetter{wrapped,
@@ -97,6 +102,9 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
97102
func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) }
98103
func (conn *Conn) Close() error { return conn.Conn.Close() }
99104
func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() }
105+
func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
106+
return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
107+
}
100108

101109
// ExecerContext implements a database/sql.driver.ExecerContext
102110
type ExecerContext struct {

sqlhooks_interface_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
6363
*FakeConnQueryer
6464
*FakeConnSessionResetter
6565
}{}, nil
66+
case "NonConnBeginTx":
67+
return &FakeConnUnsupported{}, nil
6668
}
6769

6870
return nil, errors.New("Fake driver not implemented")
@@ -80,6 +82,9 @@ func (*FakeConnBasic) Close() error {
8082
func (*FakeConnBasic) Begin() (driver.Tx, error) {
8183
return nil, errors.New("Not implemented")
8284
}
85+
func (*FakeConnBasic) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
86+
return nil, errors.New("Not implemented")
87+
}
8388

8489
type FakeConnExecer struct{}
8590

@@ -111,6 +116,20 @@ func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error {
111116
return errors.New("Not implemented")
112117
}
113118

119+
// FakeConnUnsupported implements a database/sql.driver.Conn but doesn't implement
120+
// driver.ConnBeginTx.
121+
type FakeConnUnsupported struct{}
122+
123+
func (*FakeConnUnsupported) Prepare(query string) (driver.Stmt, error) {
124+
return nil, errors.New("Not implemented")
125+
}
126+
func (*FakeConnUnsupported) Close() error {
127+
return errors.New("Not implemented")
128+
}
129+
func (*FakeConnUnsupported) Begin() (driver.Tx, error) {
130+
return nil, errors.New("Not implemented")
131+
}
132+
114133
func TestInterfaces(t *testing.T) {
115134
drv := Wrap(&fakeDriver{}, &testHooks{})
116135

@@ -123,3 +142,9 @@ func TestInterfaces(t *testing.T) {
123142
}
124143
}
125144
}
145+
146+
func TestUnsupportedDrivers(t *testing.T) {
147+
drv := Wrap(&fakeDriver{}, &testHooks{})
148+
_, err := drv.Open("NonConnBeginTx")
149+
require.EqualError(t, err, "driver must implement driver.ConnBeginTx")
150+
}

0 commit comments

Comments
 (0)