diff --git a/client/conn.go b/client/conn.go index ae28b2c5d..c5dd5ad2a 100644 --- a/client/conn.go +++ b/client/conn.go @@ -10,6 +10,7 @@ import ( . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" + "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" ) @@ -43,6 +44,9 @@ type SelectPerRowCallback func(row []FieldValue) error // This function will be called once per result from ExecuteSelectStreaming type SelectPerResultCallback func(result *Result) error +// This function will be called once per result from ExecuteMultiple +type ExecPerResultCallback func(result *Result, err error) + func getNetProto(addr string) string { proto := "tcp" if strings.Contains(addr, "/") { @@ -198,6 +202,68 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { } } +// ExecuteMultiple will call perResultCallback for every result of the multiple queries +// that are executed. +// +// When ExecuteMultiple is used, the connection should have the SERVER_MORE_RESULTS_EXISTS +// flag set to signal the server multiple queries are executed. Handling the responses +// is up to the implementation of perResultCallback. +// +// Example: +// +// queries := "SELECT 1; SELECT NOW();" +// conn.ExecuteMultiple(queries, func(result *mysql.Result, err error) { +// // Use the result as you want +// }) +// +func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { + if err := c.writeCommandStr(COM_QUERY, query); err != nil { + return nil, errors.Trace(err) + } + + var buf []byte + var err error + var result *Result + defer utils.ByteSlicePut(buf) + + for { + buf, err = c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0]) + if err != nil { + return nil, errors.Trace(err) + } + + switch buf[0] { + case OK_HEADER: + result, err = c.handleOKPacket(buf) + case ERR_HEADER: + err = c.handleErrorPacket(append([]byte{}, buf...)) + result = nil + case LocalInFile_HEADER: + err = ErrMalformPacket + result = nil + default: + result, err = c.readResultset(buf, false) + } + + // call user-defined callback + perResultCallback(result, err) + + // if there was an error of this was the last result, stop looping + if err != nil || result.Status&SERVER_MORE_RESULTS_EXISTS == 0 { + break + } + } + + // return an empty result(set) signaling we're done streaming a multiple + // streaming session + // if this would end up in WriteValue, it would just be ignored as all + // responses should have been handled in perResultCallback + return &Result{Resultset: &Resultset{ + Streaming: StreamingMultiple, + StreamingDone: true, + }}, nil +} + // ExecuteSelectStreaming will call perRowCallback for every row in resultset // WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields. // When given, perResultCallback will be called once per result diff --git a/client/conn_test.go b/client/conn_test.go index d2896caa0..338bb6ae3 100644 --- a/client/conn_test.go +++ b/client/conn_test.go @@ -2,6 +2,7 @@ package client import ( "fmt" + "strings" . "github.com/pingcap/check" @@ -16,7 +17,10 @@ type connTestSuite struct { func (s *connTestSuite) SetUpSuite(c *C) { var err error addr := fmt.Sprintf("%s:%s", *testHost, s.port) - s.c, err = Connect(addr, *testUser, *testPassword, "") + s.c, err = Connect(addr, *testUser, *testPassword, "", func(c *Conn) { + // required for the ExecuteMultiple test + c.SetCapability(mysql.CLIENT_MULTI_STATEMENTS) + }) if err != nil { c.Fatal(err) } @@ -78,6 +82,46 @@ func (s *connTestSuite) testExecute_DropTable(c *C) { c.Assert(err, IsNil) } +func (s *connTestSuite) TestExecuteMultiple(c *C) { + queries := []string{ + `INSERT INTO ` + testExecuteSelectStreamingTablename + ` (id, str) VALUES (999, "executemultiple")`, + `SELECT id FROM ` + testExecuteSelectStreamingTablename + ` LIMIT 2`, + `DELETE FROM ` + testExecuteSelectStreamingTablename + ` WHERE id=999`, + `THIS IS BOGUS()`, + } + + count := 0 + result, err := s.c.ExecuteMultiple(strings.Join(queries, "; "), func(result *mysql.Result, err error) { + switch count { + // the INSERT/DELETE query have no resultset, but should have set affectedrows + // the err should be nil + // also, since this is not the last query, the SERVER_MORE_RESULTS_EXISTS + // flag should be set + case 0, 2: + c.Assert(result.Status&mysql.SERVER_MORE_RESULTS_EXISTS, Not(Equals), 0) + c.Assert(result.Resultset, IsNil) + c.Assert(result.AffectedRows, Equals, uint64(1)) + c.Assert(err, IsNil) + case 1: + // the SELECT query should have an resultset + // still not the last query, flag should be set + c.Assert(result.Status&mysql.SERVER_MORE_RESULTS_EXISTS, Not(Equals), 0) + c.Assert(result.Resultset, NotNil) + c.Assert(err, IsNil) + case 3: + // this query is obviously bogus so the error should be non-nil + c.Assert(result, IsNil) + c.Assert(err, NotNil) + } + count++ + }) + + c.Assert(count, Equals, 4) + c.Assert(err, IsNil) + c.Assert(result.StreamingDone, Equals, true) + c.Assert(result.Streaming, Equals, mysql.StreamingMultiple) +} + func (s *connTestSuite) TestExecuteSelectStreaming(c *C) { var ( expectedRowId int64 diff --git a/client/resp.go b/client/resp.go index 21d45dab7..0f5215ebf 100644 --- a/client/resp.go +++ b/client/resp.go @@ -309,7 +309,7 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, } // this is a streaming resultset - result.Resultset.Streaming = true + result.Resultset.Streaming = StreamingSelect if err := c.readResultColumns(result); err != nil { return errors.Trace(err) diff --git a/mysql/resultset.go b/mysql/resultset.go index c3ad7ef1b..d90796b83 100644 --- a/mysql/resultset.go +++ b/mysql/resultset.go @@ -9,6 +9,19 @@ import ( "github.com/siddontang/go/hack" ) +type StreamingType int + +const ( + // StreamingNone means there is no streaming + StreamingNone StreamingType = iota + // StreamingSelect is used with select queries for which each result is + // directly returned to the client + StreamingSelect + // StreamingMultiple is used when multiple queries are given at once + // usually in combination with SERVER_MORE_RESULTS_EXISTS flag set + StreamingMultiple +) + type Resultset struct { Fields []*Field FieldNames map[string]int @@ -18,7 +31,7 @@ type Resultset struct { RowDatas []RowData - Streaming bool + Streaming StreamingType StreamingDone bool } diff --git a/server/resp.go b/server/resp.go index 63c64bdea..35b742e9c 100644 --- a/server/resp.go +++ b/server/resp.go @@ -119,8 +119,15 @@ func (c *Conn) writeResultset(r *Resultset) error { // for a streaming resultset, that handled rowdata separately in a callback // of type SelectPerRowCallback, we can suffice by ending the stream with // an EOF + // when streaming multiple queries, no EOF has to be sent, all results should've + // been taken care of already in the user-defined callback if r.StreamingDone { - return c.writeEOF() + switch r.Streaming { + case StreamingMultiple: + return nil + case StreamingSelect: + return c.writeEOF() + } } columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) @@ -136,9 +143,9 @@ func (c *Conn) writeResultset(r *Resultset) error { return err } - // streaming resultsets handle rowdata in a separate callback of type + // streaming select resultsets handle rowdata in a separate callback of type // SelectPerRowCallback so we're done here - if r.Streaming { + if r.Streaming == StreamingSelect { return nil }