Skip to content

Support MySQL Compressed Protocol #787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ func (c *Conn) writeAuthHandshake() error {
// in the library are supported here
capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS |
c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM

// To enable TLS / SSL
if c.tlsConfig != nil {
Expand Down Expand Up @@ -247,6 +248,9 @@ func (c *Conn) writeAuthHandshake() error {
capability |= CLIENT_CONNECT_ATTRS
length += len(attrData)
}
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
length++
}

data := make([]byte, length+4)

Expand Down Expand Up @@ -320,7 +324,12 @@ func (c *Conn) writeAuthHandshake() error {

// connection attributes
if len(attrData) > 0 {
copy(data[pos:], attrData)
pos += copy(data[pos:], attrData)
}

if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
// zstd_compression_level
data[pos] = 0x03
}

return c.WritePacket(data)
Expand Down
13 changes: 13 additions & 0 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,12 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st
return nil, errors.Trace(err)
}

if c.ccaps&CLIENT_COMPRESS > 0 {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should notify the user about their priority at the comment of SetCapability

c.Conn.Compression = MYSQL_COMPRESS_ZLIB
} else if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
c.Conn.Compression = MYSQL_COMPRESS_ZSTD
}

return c, nil
}

Expand Down Expand Up @@ -149,6 +155,13 @@ func (c *Conn) Close() error {
return c.Conn.Close()
}

func (c *Conn) Quit() error {
if err := c.writeCommand(COM_QUIT); err != nil {
return err
}
return c.Close()
}

func (c *Conn) Ping() error {
if err := c.writeCommand(COM_PING); err != nil {
return errors.Trace(err)
Expand Down
6 changes: 6 additions & 0 deletions mysql/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,9 @@ const (
MYSQL_OPTION_MULTI_STATEMENTS_ON = iota
MYSQL_OPTION_MULTI_STATEMENTS_OFF
)

const (
MYSQL_COMPRESS_NONE = iota
MYSQL_COMPRESS_ZLIB
MYSQL_COMPRESS_ZSTD
)
150 changes: 138 additions & 12 deletions packet/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package packet
import (
"bufio"
"bytes"
"compress/zlib"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
Expand All @@ -12,6 +13,7 @@ import (
"net"
"sync"

"github.com/DataDog/zstd"
. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/utils"
"github.com/pingcap/errors"
Expand Down Expand Up @@ -56,6 +58,16 @@ type Conn struct {
header [4]byte

Sequence uint8

Compression uint8

CompressedSequence uint8

compressedHeader [7]byte

compressedReaderActive bool

compressedReader io.Reader
}

func NewConn(conn net.Conn) *Conn {
Expand Down Expand Up @@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
utils.BytesBufferPut(buf)
}()

if err := c.ReadPacketTo(buf); err != nil {
return nil, errors.Trace(err)
if c.Compression != MYSQL_COMPRESS_NONE {
if !c.compressedReaderActive {
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
}

compressedSequence := c.compressedHeader[3]
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
if compressedSequence != c.CompressedSequence {
return nil, errors.Errorf("invalid compressed sequence %d != %d",
compressedSequence, c.CompressedSequence)
}

if uncompressedLength > 0 {
var err error
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
c.compressedReader, err = zlib.NewReader(c.reader)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we unify the c.compressedReader with c.reader by interface?

case MYSQL_COMPRESS_ZSTD:
c.compressedReader = zstd.NewReader(c.reader)
}
if err != nil {
return nil, err
}
}
c.compressedReaderActive = true
}
}

if c.compressedReader != nil {
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
return nil, errors.Trace(err)
}
} else {
if err := c.ReadPacketTo(buf, c.reader); err != nil {
return nil, errors.Trace(err)
}
}

readBytes := buf.Bytes()
Expand Down Expand Up @@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
return written, nil
}

func (c *Conn) ReadPacketTo(w io.Writer) error {
if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil {
func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
if _, err := io.ReadFull(r, c.header[:4]); err != nil {
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
}

Expand All @@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
buf.Grow(length)
}

if n, err := c.copyN(w, c.reader, int64(length)); err != nil {
if n, err := c.copyN(w, r, int64(length)); err != nil {
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
} else if n != int64(length) {
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
Expand All @@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error {
return nil
}

if err := c.ReadPacketTo(w); err != nil {
if err = c.ReadPacketTo(w, r); err != nil {
return errors.Wrap(err, "ReadPacketTo failed")
}
}
Expand Down Expand Up @@ -209,14 +256,93 @@ func (c *Conn) WritePacket(data []byte) error {
data[2] = byte(length >> 16)
data[3] = c.Sequence

if n, err := c.Write(data); err != nil {
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
} else if n != len(data) {
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
switch c.Compression {
case MYSQL_COMPRESS_NONE:
if n, err := c.Write(data); err != nil {
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
} else if n != len(data) {
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
}
case MYSQL_COMPRESS_ZLIB, MYSQL_COMPRESS_ZSTD:
if n, err := c.writeCompressed(data); err != nil {
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err)
} else if n != len(data) {
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
}
c.compressedReader = nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to set it to nil?

Copy link
Collaborator Author

@dveeden dveeden May 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On line 138:

        if c.compressedReader != nil {
                if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
                        return nil, errors.Trace(err)
                }
        } else {

I could move the c.compressedReader = nil line to here (line 122) if that would make things more obvious:

                        if uncompressedLength > 0 {
                                var err error
                                switch c.Compression {
                                case MYSQL_COMPRESS_ZLIB:
                                        c.compressedReader, err = zlib.NewReader(c.reader)
                                case MYSQL_COMPRESS_ZSTD:
                                        c.compressedReader = zstd.NewReader(c.reader)
                                }
                                if err != nil {
                                        return nil, err
                                }
                        }

The payload of a compressed packet is not always actually compressed. Payloads of less than 50 bytes are often uncompressed as the compression overhead isn't worth it. If the uncompressedLength is 0 then the payload isn't actually compressed.

c.compressedReaderActive = false
default:
return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set")
}

c.Sequence++
return nil
}

func (c *Conn) writeCompressed(data []byte) (n int, err error) {
var compressedLength, uncompressedLength int
var payload, compressedPacket bytes.Buffer
var w io.WriteCloser
minCompressLength := 50
compressedHeader := make([]byte, 7)

switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
w, err = zlib.NewWriterLevel(&payload, zlib.HuffmanOnly)
case MYSQL_COMPRESS_ZSTD:
w = zstd.NewWriter(&payload)
}
if err != nil {
return 0, err
}

if len(data) > minCompressLength {
uncompressedLength = len(data)
n, err = w.Write(data)
if err != nil {
return 0, err
}
err = w.Close()
if err != nil {
return 0, err
}
}

if len(data) > minCompressLength {
compressedLength = len(payload.Bytes())
} else {
compressedLength = len(data)
}

c.CompressedSequence = 0
compressedHeader[0] = byte(compressedLength)
compressedHeader[1] = byte(compressedLength >> 8)
compressedHeader[2] = byte(compressedLength >> 16)
compressedHeader[3] = c.CompressedSequence
compressedHeader[4] = byte(uncompressedLength)
compressedHeader[5] = byte(uncompressedLength >> 8)
compressedHeader[6] = byte(uncompressedLength >> 16)
_, err = compressedPacket.Write(compressedHeader)
if err != nil {
return 0, err
}
c.CompressedSequence++

if len(data) > minCompressLength {
_, err = compressedPacket.Write(payload.Bytes())
} else {
c.Sequence++
return nil
n, err = compressedPacket.Write(data)
}
if err != nil {
return 0, err
}

_, err = c.Write(compressedPacket.Bytes())
if err != nil {
return 0, err
}

return n, nil
}

// WriteClearAuthPacket: Client clear text authentication packet
Expand Down