Skip to content
This repository was archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

Commit 5750e3f

Browse files
authored
Add context support (#1193)
* add context support * improve pingcontext tests
1 parent 229c3aa commit 5750e3f

11 files changed

+100
-27
lines changed

engine.go

+3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package xorm
77
import (
88
"bufio"
99
"bytes"
10+
"context"
1011
"database/sql"
1112
"encoding/gob"
1213
"errors"
@@ -52,6 +53,8 @@ type Engine struct {
5253

5354
cachers map[string]core.Cacher
5455
cacherLock sync.RWMutex
56+
57+
defaultContext context.Context
5558
}
5659

5760
func (engine *Engine) setCacher(tableName string, cacher core.Cacher) {

engine_context.go

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright 2019 The Xorm Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// +build go1.8
6+
7+
package xorm
8+
9+
import "context"
10+
11+
// Context creates a session with the context
12+
func (engine *Engine) Context(ctx context.Context) *Session {
13+
session := engine.NewSession()
14+
session.isAutoClose = true
15+
return session.Context(ctx)
16+
}
17+
18+
// SetDefaultContext set the default context
19+
func (engine *Engine) SetDefaultContext(ctx context.Context) {
20+
engine.defaultContext = ctx
21+
}
22+
23+
// PingContext tests if database is alive
24+
func (engine *Engine) PingContext(ctx context.Context) error {
25+
session := engine.NewSession()
26+
defer session.Close()
27+
return session.PingContext(ctx)
28+
}

context_test.go renamed to engine_context_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@ import (
1717
func TestPingContext(t *testing.T) {
1818
assert.NoError(t, prepareEngine())
1919

20-
ctx, canceled := context.WithTimeout(context.Background(), 10*time.Second)
20+
ctx, canceled := context.WithTimeout(context.Background(), time.Nanosecond)
2121
defer canceled()
2222

2323
err := testEngine.(*Engine).PingContext(ctx)
24-
assert.NoError(t, err)
24+
assert.Error(t, err)
25+
assert.Contains(t, err.Error(), "context deadline exceeded")
2526
}

interface.go

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package xorm
66

77
import (
8+
"context"
89
"database/sql"
910
"reflect"
1011
"time"
@@ -73,6 +74,7 @@ type EngineInterface interface {
7374
Before(func(interface{})) *Session
7475
Charset(charset string) *Session
7576
ClearCache(...interface{}) error
77+
Context(context.Context) *Session
7678
CreateTables(...interface{}) error
7779
DBMetas() ([]*core.Table, error)
7880
Dialect() core.Dialect

session.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package xorm
66

77
import (
8+
"context"
89
"database/sql"
910
"encoding/json"
1011
"errors"
@@ -52,6 +53,7 @@ type Session struct {
5253
lastSQLArgs []interface{}
5354

5455
err error
56+
ctx context.Context
5557
}
5658

5759
// Clone copy all the session's content and return a new session
@@ -82,6 +84,8 @@ func (session *Session) Init() {
8284

8385
session.lastSQL = ""
8486
session.lastSQLArgs = []interface{}{}
87+
88+
session.ctx = session.engine.defaultContext
8589
}
8690

8791
// Close release the connection from pool
@@ -275,7 +279,7 @@ func (session *Session) doPrepare(db *core.DB, sqlStr string) (stmt *core.Stmt,
275279
var has bool
276280
stmt, has = session.stmtCache[crc]
277281
if !has {
278-
stmt, err = db.Prepare(sqlStr)
282+
stmt, err = db.PrepareContext(session.ctx, sqlStr)
279283
if err != nil {
280284
return nil, err
281285
}

context.go renamed to session_context.go

+5-8
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
1-
// Copyright 2017 The Xorm Authors. All rights reserved.
1+
// Copyright 2019 The Xorm Authors. All rights reserved.
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

5-
// +build go1.8
6-
75
package xorm
86

97
import "context"
108

11-
// PingContext tests if database is alive
12-
func (engine *Engine) PingContext(ctx context.Context) error {
13-
session := engine.NewSession()
14-
defer session.Close()
15-
return session.PingContext(ctx)
9+
// Context sets the context on this session
10+
func (session *Session) Context(ctx context.Context) *Session {
11+
session.ctx = ctx
12+
return session
1613
}
1714

1815
// PingContext test if database is ok

session_context_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// Copyright 2019 The Xorm Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package xorm
6+
7+
import (
8+
"context"
9+
"testing"
10+
"time"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestQueryContext(t *testing.T) {
16+
type ContextQueryStruct struct {
17+
Id int64
18+
Name string
19+
}
20+
21+
assert.NoError(t, prepareEngine())
22+
assertSync(t, new(ContextQueryStruct))
23+
24+
_, err := testEngine.Insert(&ContextQueryStruct{Name: "1"})
25+
assert.NoError(t, err)
26+
27+
sess := testEngine.NewSession()
28+
defer sess.Close()
29+
30+
ctx, cancel := context.WithTimeout(context.Background(), time.Microsecond)
31+
defer cancel()
32+
has, err := testEngine.Context(ctx).Exist(&ContextQueryStruct{Name: "1"})
33+
assert.Error(t, err)
34+
assert.Contains(t, err.Error(), "context deadline exceeded")
35+
assert.False(t, has)
36+
}

session_raw.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,21 @@ func (session *Session) queryRows(sqlStr string, args ...interface{}) (*core.Row
6262
return nil, err
6363
}
6464

65-
rows, err := stmt.Query(args...)
65+
rows, err := stmt.QueryContext(session.ctx, args...)
6666
if err != nil {
6767
return nil, err
6868
}
6969
return rows, nil
7070
}
7171

72-
rows, err := db.Query(sqlStr, args...)
72+
rows, err := db.QueryContext(session.ctx, sqlStr, args...)
7373
if err != nil {
7474
return nil, err
7575
}
7676
return rows, nil
7777
}
7878

79-
rows, err := session.tx.Query(sqlStr, args...)
79+
rows, err := session.tx.QueryContext(session.ctx, sqlStr, args...)
8080
if err != nil {
8181
return nil, err
8282
}
@@ -175,7 +175,7 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
175175
}
176176

177177
if !session.isAutoCommit {
178-
return session.tx.Exec(sqlStr, args...)
178+
return session.tx.ExecContext(session.ctx, sqlStr, args...)
179179
}
180180

181181
if session.prepareStmt {
@@ -184,14 +184,14 @@ func (session *Session) exec(sqlStr string, args ...interface{}) (sql.Result, er
184184
return nil, err
185185
}
186186

187-
res, err := stmt.Exec(args...)
187+
res, err := stmt.ExecContext(session.ctx, args...)
188188
if err != nil {
189189
return nil, err
190190
}
191191
return res, nil
192192
}
193193

194-
return session.DB().Exec(sqlStr, args...)
194+
return session.DB().ExecContext(session.ctx, sqlStr, args...)
195195
}
196196

197197
func convertSQLOrArgs(sqlorArgs ...interface{}) (string, []interface{}, error) {

session_schema.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ func (session *Session) Ping() error {
1919
}
2020

2121
session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
22-
return session.DB().Ping()
22+
return session.DB().PingContext(session.ctx)
2323
}
2424

2525
// CreateTable create a table according a bean

session_tx.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ package xorm
77
// Begin a transaction
88
func (session *Session) Begin() error {
99
if session.isAutoCommit {
10-
tx, err := session.DB().Begin()
10+
tx, err := session.DB().BeginTx(session.ctx, nil)
1111
if err != nil {
1212
return err
1313
}

xorm.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
package xorm
88

99
import (
10+
"context"
1011
"fmt"
1112
"os"
1213
"reflect"
@@ -85,14 +86,15 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
8586
}
8687

8788
engine := &Engine{
88-
db: db,
89-
dialect: dialect,
90-
Tables: make(map[reflect.Type]*core.Table),
91-
mutex: &sync.RWMutex{},
92-
TagIdentifier: "xorm",
93-
TZLocation: time.Local,
94-
tagHandlers: defaultTagHandlers,
95-
cachers: make(map[string]core.Cacher),
89+
db: db,
90+
dialect: dialect,
91+
Tables: make(map[reflect.Type]*core.Table),
92+
mutex: &sync.RWMutex{},
93+
TagIdentifier: "xorm",
94+
TZLocation: time.Local,
95+
tagHandlers: defaultTagHandlers,
96+
cachers: make(map[string]core.Cacher),
97+
defaultContext: context.Background(),
9698
}
9799

98100
if uri.DbType == core.SQLITE {

0 commit comments

Comments
 (0)