Skip to content

SELECT streaming #560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,27 @@ Tested MySQL versions for the client include:
- 5.7.x
- 8.0.x

### Example for SELECT streaming (v.1.1.1)
You can use also streaming for large SELECT responses.
The callback function will be called for every result row without storing the whole resultset in memory.
`result.Fields` will be filled before the first callback call.

```go
// ...
var result mysql.Result
err := conn.ExecuteSelectStreaming(`select id, name from table LIMIT 100500`, &result, func(row []mysql.FieldValue) error {
for idx, val := range row {
field := result.Fields[idx]
// You must not save FieldValue.AsString() value after this callback is done.
// Copy it if you need.
// ...
}
return false, nil
})

// ...
```

## Server

Server package supplies a framework to implement a simple MySQL server which can handle the packets from the MySQL client.
Expand Down
25 changes: 25 additions & 0 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ type Conn struct {
connectionID uint32
}

// This function will be called for every row in resultset from ExecuteSelectStreaming.
type SelectPerRowCallback func(row []FieldValue) error

func getNetProto(addr string) string {
proto := "tcp"
if strings.Contains(addr, "/") {
Expand Down Expand Up @@ -165,6 +168,28 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) {
}
}

// ExecuteSelectStreaming will call perRowCallback for every row in resultset
// WITHOUT saving any row data to Result.{Values/RawPkg/RowDatas} fields.
//
// ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving.
//
// Example:
//
// var result mysql.Result
// conn.ExecuteSelectStreaming(`SELECT ... LIMIT 100500`, &result, func(row []mysql.FieldValue) error {
// // Use the row as you want.
// // You must not save FieldValue.AsString() value after this callback is done. Copy it if you need.
// return nil
// })
//
func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback) error {
if err := c.writeCommandStr(COM_QUERY, command); err != nil {
return errors.Trace(err)
}

return c.readResultStreaming(false, result, perRowCallback)
}

func (c *Conn) Begin() error {
_, err := c.exec("BEGIN")
return errors.Trace(err)
Expand Down
88 changes: 88 additions & 0 deletions client/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,25 @@ func (c *Conn) readResult(binary bool) (*Result, error) {
return c.readResultset(firstPkgBuf, binary)
}

func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback) error {
firstPkgBuf, err := c.ReadPacketReuseMem(utils.ByteSliceGet(16)[:0])
defer utils.ByteSlicePut(firstPkgBuf)

if err != nil {
return errors.Trace(err)
}

if firstPkgBuf[0] == OK_HEADER {
return ErrMalformPacket // Streaming allowed only for SELECT queries
} else if firstPkgBuf[0] == ERR_HEADER {
return c.handleErrorPacket(append([]byte{}, firstPkgBuf...))
} else if firstPkgBuf[0] == LocalInFile_HEADER {
return ErrMalformPacket
}

return c.readResultsetStreaming(firstPkgBuf, binary, result, perRowCb)
}

func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
// column count
count, _, n := LengthEncodedInt(data)
Expand All @@ -256,6 +275,31 @@ func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) {
return result, nil
}

func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback) error {
columnCount, _, n := LengthEncodedInt(data)

if n-len(data) != 0 {
return ErrMalformPacket
}

if result.Resultset == nil {
result.Resultset = NewResultset(int(columnCount))
} else {
// Reuse memory if can
result.Reset(int(columnCount))
}

if err := c.readResultColumns(result); err != nil {
return errors.Trace(err)
}

if err := c.readResultRowsStreaming(result, binary, perRowCb); err != nil {
return errors.Trace(err)
}

return nil
}

func (c *Conn) readResultColumns(result *Result) (err error) {
var i int = 0
var data []byte
Expand Down Expand Up @@ -344,3 +388,47 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) {

return nil
}

func (c *Conn) readResultRowsStreaming(result *Result, isBinary bool, perRowCb SelectPerRowCallback) (err error) {
var (
data []byte
row []FieldValue
)

for {
data, err = c.ReadPacketReuseMem(data[:0])
if err != nil {
return
}

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&CLIENT_PROTOCOL_41 > 0 {
// result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}

break
}

if data[0] == ERR_HEADER {
return c.handleErrorPacket(data)
}

// Parse this row
row, err = RowData(data).Parse(result.Fields, isBinary, row)
if err != nil {
return errors.Trace(err)
}

// Send the row to "userland" code
err = perRowCb(row)
if err != nil {
return errors.Trace(err)
}
}

return nil
}
14 changes: 7 additions & 7 deletions mysql/resultset.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@ var (
}
)

func NewResultset(resultsetCount int) *Resultset {
func NewResultset(fieldsCount int) *Resultset {
r := resultsetPool.Get().(*Resultset)
r.reset(resultsetCount)
r.Reset(fieldsCount)
return r
}

func (r *Resultset) returnToPool() {
resultsetPool.Put(r)
}

func (r *Resultset) reset(count int) {
func (r *Resultset) Reset(fieldsCount int) {
r.RawPkg = r.RawPkg[:0]

r.Fields = r.Fields[:0]
Expand All @@ -52,14 +52,14 @@ func (r *Resultset) reset(count int) {
r.FieldNames = make(map[string]int)
}

if count == 0 {
if fieldsCount == 0 {
return
}

if cap(r.Fields) < count {
r.Fields = make([]*Field, count)
if cap(r.Fields) < fieldsCount {
r.Fields = make([]*Field, fieldsCount)
} else {
r.Fields = r.Fields[:count]
r.Fields = r.Fields[:fieldsCount]
}
}

Expand Down