Skip to content

Commit 051905b

Browse files
authored
Merge pull request #668 from skoef/execMultiple
implemented ExecuteMultiple
2 parents 3566d1e + 49a4288 commit 051905b

File tree

5 files changed

+136
-6
lines changed

5 files changed

+136
-6
lines changed

client/conn.go

+66
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
. "github.com/go-mysql-org/go-mysql/mysql"
1212
"github.com/go-mysql-org/go-mysql/packet"
13+
"github.com/go-mysql-org/go-mysql/utils"
1314
"github.com/pingcap/errors"
1415
)
1516

@@ -43,6 +44,9 @@ type SelectPerRowCallback func(row []FieldValue) error
4344
// This function will be called once per result from ExecuteSelectStreaming
4445
type SelectPerResultCallback func(result *Result) error
4546

47+
// This function will be called once per result from ExecuteMultiple
48+
type ExecPerResultCallback func(result *Result, err error)
49+
4650
func getNetProto(addr string) string {
4751
proto := "tcp"
4852
if strings.Contains(addr, "/") {
@@ -198,6 +202,68 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
198202
}
199203
}
200204

205+
// ExecuteMultiple will call perResultCallback for every result of the multiple queries
206+
// that are executed.
207+
//
208+
// When ExecuteMultiple is used, the connection should have the SERVER_MORE_RESULTS_EXISTS
209+
// flag set to signal the server multiple queries are executed. Handling the responses
210+
// is up to the implementation of perResultCallback.
211+
//
212+
// Example:
213+
//
214+
// queries := "SELECT 1; SELECT NOW();"
215+
// conn.ExecuteMultiple(queries, func(result *mysql.Result, err error) {
216+
// // Use the result as you want
217+
// })
218+
//
219+
func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) {
220+
if err := c.writeCommandStr(COM_QUERY, query); err != nil {
221+
return nil, errors.Trace(err)
222+
}
223+
224+
var buf []byte
225+
var err error
226+
var result *Result
227+
defer utils.ByteSlicePut(buf)
228+
229+
for {
230+
buf, err = c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0])
231+
if err != nil {
232+
return nil, errors.Trace(err)
233+
}
234+
235+
switch buf[0] {
236+
case OK_HEADER:
237+
result, err = c.handleOKPacket(buf)
238+
case ERR_HEADER:
239+
err = c.handleErrorPacket(append([]byte{}, buf...))
240+
result = nil
241+
case LocalInFile_HEADER:
242+
err = ErrMalformPacket
243+
result = nil
244+
default:
245+
result, err = c.readResultset(buf, false)
246+
}
247+
248+
// call user-defined callback
249+
perResultCallback(result, err)
250+
251+
// if there was an error of this was the last result, stop looping
252+
if err != nil || result.Status&SERVER_MORE_RESULTS_EXISTS == 0 {
253+
break
254+
}
255+
}
256+
257+
// return an empty result(set) signaling we're done streaming a multiple
258+
// streaming session
259+
// if this would end up in WriteValue, it would just be ignored as all
260+
// responses should have been handled in perResultCallback
261+
return &Result{Resultset: &Resultset{
262+
Streaming: StreamingMultiple,
263+
StreamingDone: true,
264+
}}, nil
265+
}
266+
201267
// ExecuteSelectStreaming will call perRowCallback for every row in resultset
202268
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
203269
// When given, perResultCallback will be called once per result

client/conn_test.go

+45-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package client
22

33
import (
44
"fmt"
5+
"strings"
56

67
. "github.com/pingcap/check"
78

@@ -16,7 +17,10 @@ type connTestSuite struct {
1617
func (s *connTestSuite) SetUpSuite(c *C) {
1718
var err error
1819
addr := fmt.Sprintf("%s:%s", *testHost, s.port)
19-
s.c, err = Connect(addr, *testUser, *testPassword, "")
20+
s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) {
21+
// required for the ExecuteMultiple test
22+
c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS)
23+
})
2024
if err != nil {
2125
c.Fatal(err)
2226
}
@@ -78,6 +82,46 @@ func (s *connTestSuite) testExecute_DropTable(c *C) {
7882
c.Assert(err, IsNil)
7983
}
8084

85+
func (s *connTestSuite) TestExecuteMultiple(c *C) {
86+
queries := []string{
87+
`INSERT INTO ` + testExecuteSelectStreamingTablename + ` (id, str) VALUES (999, "executemultiple")`,
88+
`SELECT id FROM ` + testExecuteSelectStreamingTablename + ` LIMIT 2`,
89+
`DELETE FROM ` + testExecuteSelectStreamingTablename + ` WHERE id=999`,
90+
`THIS IS BOGUS()`,
91+
}
92+
93+
count := 0
94+
result, err := s.c.ExecuteMultiple(strings.Join(queries, "; "), func(result *mysql.Result, err error) {
95+
switch count {
96+
// the INSERT/DELETE query have no resultset, but should have set affectedrows
97+
// the err should be nil
98+
// also, since this is not the last query, the SERVER_MORE_RESULTS_EXISTS
99+
// flag should be set
100+
case 0, 2:
101+
c.Assert(result.Status&mysql.SERVER_MORE_RESULTS_EXISTS, Not(Equals), 0)
102+
c.Assert(result.Resultset, IsNil)
103+
c.Assert(result.AffectedRows, Equals, uint64(1))
104+
c.Assert(err, IsNil)
105+
case 1:
106+
// the SELECT query should have an resultset
107+
// still not the last query, flag should be set
108+
c.Assert(result.Status&mysql.SERVER_MORE_RESULTS_EXISTS, Not(Equals), 0)
109+
c.Assert(result.Resultset, NotNil)
110+
c.Assert(err, IsNil)
111+
case 3:
112+
// this query is obviously bogus so the error should be non-nil
113+
c.Assert(result, IsNil)
114+
c.Assert(err, NotNil)
115+
}
116+
count++
117+
})
118+
119+
c.Assert(count, Equals, 4)
120+
c.Assert(err, IsNil)
121+
c.Assert(result.StreamingDone, Equals, true)
122+
c.Assert(result.Streaming, Equals, mysql.StreamingMultiple)
123+
}
124+
81125
func (s *connTestSuite) TestExecuteSelectStreaming(c *C) {
82126
var (
83127
expectedRowId int64

client/resp.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result,
309309
}
310310

311311
// this is a streaming resultset
312-
result.Resultset.Streaming = true
312+
result.Resultset.Streaming = StreamingSelect
313313

314314
if err := c.readResultColumns(result); err != nil {
315315
return errors.Trace(err)

mysql/resultset.go

+14-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@ import (
99
"github.com/siddontang/go/hack"
1010
)
1111

12+
type StreamingType int
13+
14+
const (
15+
// StreamingNone means there is no streaming
16+
StreamingNone StreamingType = iota
17+
// StreamingSelect is used with select queries for which each result is
18+
// directly returned to the client
19+
StreamingSelect
20+
// StreamingMultiple is used when multiple queries are given at once
21+
// usually in combination with SERVER_MORE_RESULTS_EXISTS flag set
22+
StreamingMultiple
23+
)
24+
1225
type Resultset struct {
1326
Fields []*Field
1427
FieldNames map[string]int
@@ -18,7 +31,7 @@ type Resultset struct {
1831

1932
RowDatas []RowData
2033

21-
Streaming bool
34+
Streaming StreamingType
2235
StreamingDone bool
2336
}
2437

server/resp.go

+10-3
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,15 @@ func (c *Conn) writeResultset(r *Resultset) error {
119119
// for a streaming resultset, that handled rowdata separately in a callback
120120
// of type SelectPerRowCallback, we can suffice by ending the stream with
121121
// an EOF
122+
// when streaming multiple queries, no EOF has to be sent, all results should've
123+
// been taken care of already in the user-defined callback
122124
if r.StreamingDone {
123-
return c.writeEOF()
125+
switch r.Streaming {
126+
case StreamingMultiple:
127+
return nil
128+
case StreamingSelect:
129+
return c.writeEOF()
130+
}
124131
}
125132

126133
columnLen := PutLengthEncodedInt(uint64(len(r.Fields)))
@@ -136,9 +143,9 @@ func (c *Conn) writeResultset(r *Resultset) error {
136143
return err
137144
}
138145

139-
// streaming resultsets handle rowdata in a separate callback of type
146+
// streaming select resultsets handle rowdata in a separate callback of type
140147
// SelectPerRowCallback so we're done here
141-
if r.Streaming {
148+
if r.Streaming == StreamingSelect {
142149
return nil
143150
}
144151

0 commit comments

Comments
 (0)