From 726392540d4b5f2779d1d843816a69f48139d3ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 5 May 2023 12:37:28 +0200 Subject: [PATCH 1/3] Support MySQL Compressed Protocol --- client/auth.go | 14 ++++- client/conn.go | 13 +++++ mysql/const.go | 6 ++ packet/conn.go | 152 +++++++++++++++++++++++++++++++++++++++++++++---- 4 files changed, 171 insertions(+), 14 deletions(-) diff --git a/client/auth.go b/client/auth.go index ff7ebcd01..a4297e7e4 100644 --- a/client/auth.go +++ b/client/auth.go @@ -201,7 +201,8 @@ func (c *Conn) writeAuthHandshake() error { // in the library are supported here capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE | c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS | - c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS + c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS | + c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM // To enable TLS / SSL if c.tlsConfig != nil { @@ -247,6 +248,9 @@ func (c *Conn) writeAuthHandshake() error { capability |= CLIENT_CONNECT_ATTRS length += len(attrData) } + if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + length++ + } data := make([]byte, length+4) @@ -320,7 +324,13 @@ func (c *Conn) writeAuthHandshake() error { // connection attributes if len(attrData) > 0 { - copy(data[pos:], attrData) + pos += copy(data[pos:], attrData) + } + + if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + // zstd_compression_level + data[pos] = 0x03 + pos++ } return c.WritePacket(data) diff --git a/client/conn.go b/client/conn.go index 8b2940ee8..22aa10629 100644 --- a/client/conn.go +++ b/client/conn.go @@ -112,6 +112,12 @@ func ConnectWithDialer(ctx context.Context, network string, addr string, user st return nil, errors.Trace(err) } + if c.ccaps&CLIENT_COMPRESS > 0 { + c.Conn.Compression = MYSQL_COMPRESS_ZLIB + } else if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + c.Conn.Compression = MYSQL_COMPRESS_ZSTD + } + return c, nil } @@ -140,6 +146,13 @@ func (c *Conn) Close() error { return c.Conn.Close() } +func (c *Conn) Quit() error { + if err := c.writeCommand(COM_QUIT); err != nil { + return err + } + return c.Close() +} + func (c *Conn) Ping() error { if err := c.writeCommand(COM_PING); err != nil { return errors.Trace(err) diff --git a/mysql/const.go b/mysql/const.go index a1a5bde42..34661294a 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -185,3 +185,9 @@ const ( MYSQL_OPTION_MULTI_STATEMENTS_ON = iota MYSQL_OPTION_MULTI_STATEMENTS_OFF ) + +const ( + MYSQL_COMPRESS_NONE = iota + MYSQL_COMPRESS_ZLIB + MYSQL_COMPRESS_ZSTD +) diff --git a/packet/conn.go b/packet/conn.go index 8d020fe92..e77718a24 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -3,6 +3,7 @@ package packet import ( "bufio" "bytes" + "compress/zlib" "crypto/rand" "crypto/rsa" "crypto/sha1" @@ -12,6 +13,7 @@ import ( "net" "sync" + "github.com/DataDog/zstd" . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" @@ -56,6 +58,16 @@ type Conn struct { header [4]byte Sequence uint8 + + Compression uint8 + + CompressedSequence uint8 + + compressedHeader [7]byte + + compressedReaderActive bool + + compressedReader io.Reader } func NewConn(conn net.Conn) *Conn { @@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { utils.BytesBufferPut(buf) }() - if err := c.ReadPacketTo(buf); err != nil { - return nil, errors.Trace(err) + if c.Compression != MYSQL_COMPRESS_NONE { + if !c.compressedReaderActive { + if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { + return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) + } + + compressedSequence := c.compressedHeader[3] + uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16) + if compressedSequence != c.CompressedSequence { + return nil, errors.Errorf("invalid compressed sequence %d != %d", + compressedSequence, c.CompressedSequence) + } + + if uncompressedLength > 0 { + var err error + switch c.Compression { + case MYSQL_COMPRESS_ZLIB: + c.compressedReader, err = zlib.NewReader(c.reader) + case MYSQL_COMPRESS_ZSTD: + c.compressedReader = zstd.NewReader(c.reader) + } + if err != nil { + return nil, err + } + } + c.compressedReaderActive = true + } + } + + if c.compressedReader != nil { + if err := c.ReadPacketTo(buf, c.compressedReader); err != nil { + return nil, errors.Trace(err) + } + } else { + if err := c.ReadPacketTo(buf, c.reader); err != nil { + return nil, errors.Trace(err) + } } readBytes := buf.Bytes() @@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err return written, nil } -func (c *Conn) ReadPacketTo(w io.Writer) error { - if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil { +func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { + if _, err := io.ReadFull(r, c.header[:4]); err != nil { return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err) } @@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { buf.Grow(length) } - if n, err := c.copyN(w, c.reader, int64(length)); err != nil { + if n, err := c.copyN(w, r, int64(length)); err != nil { return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) } else if n != int64(length) { return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) @@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { return nil } - if err := c.ReadPacketTo(w); err != nil { + if err = c.ReadPacketTo(w, r); err != nil { return errors.Wrap(err, "ReadPacketTo failed") } } @@ -209,14 +256,95 @@ func (c *Conn) WritePacket(data []byte) error { data[2] = byte(length >> 16) data[3] = c.Sequence - if n, err := c.Write(data); err != nil { - return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) - } else if n != len(data) { - return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + switch c.Compression { + case MYSQL_COMPRESS_NONE: + if n, err := c.Write(data); err != nil { + return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) + } else if n != len(data) { + return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + } + case MYSQL_COMPRESS_ZLIB: + fallthrough + case MYSQL_COMPRESS_ZSTD: + if n, err := c.writeCompressed(data); err != nil { + return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) + } else if n != len(data) { + return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + } + c.compressedReader = nil + c.compressedReaderActive = false + default: + return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set") + } + + c.Sequence++ + return nil +} + +func (c *Conn) writeCompressed(data []byte) (n int, err error) { + var compressedLength, uncompressedLength int + var payload, compressedPacket bytes.Buffer + var w io.WriteCloser + minCompressLength := 50 + compressedHeader := make([]byte, 7) + + switch c.Compression { + case MYSQL_COMPRESS_ZLIB: + w, err = zlib.NewWriterLevel(&payload, zlib.HuffmanOnly) + case MYSQL_COMPRESS_ZSTD: + w = zstd.NewWriter(&payload) + } + if err != nil { + return 0, err + } + + if len(data) > minCompressLength { + uncompressedLength = len(data) + n, err = w.Write(data) + if err != nil { + return 0, err + } + err = w.Close() + if err != nil { + return 0, err + } + } + + if len(data) > minCompressLength { + compressedLength = len(payload.Bytes()) + } else { + compressedLength = len(data) + } + + c.CompressedSequence = 0 + compressedHeader[0] = byte(compressedLength) + compressedHeader[1] = byte(compressedLength >> 8) + compressedHeader[2] = byte(compressedLength >> 16) + compressedHeader[3] = c.CompressedSequence + compressedHeader[4] = byte(uncompressedLength) + compressedHeader[5] = byte(uncompressedLength >> 8) + compressedHeader[6] = byte(uncompressedLength >> 16) + _, err = compressedPacket.Write(compressedHeader) + if err != nil { + return 0, err + } + c.CompressedSequence++ + + if len(data) > minCompressLength { + _, err = compressedPacket.Write(payload.Bytes()) } else { - c.Sequence++ - return nil + n, err = compressedPacket.Write(data) + } + if err != nil { + return 0, err } + + _, err = c.Write(compressedPacket.Bytes()) + if err != nil { + return 0, err + } + + return n, nil } // WriteClearAuthPacket: Client clear text authentication packet From 2e344337d18da07adaf0556c92184798cb1915ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 5 May 2023 14:33:00 +0200 Subject: [PATCH 2/3] Fix ineffassign --- client/auth.go | 1 - 1 file changed, 1 deletion(-) diff --git a/client/auth.go b/client/auth.go index a4297e7e4..e9c25e78b 100644 --- a/client/auth.go +++ b/client/auth.go @@ -330,7 +330,6 @@ func (c *Conn) writeAuthHandshake() error { if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { // zstd_compression_level data[pos] = 0x03 - pos++ } return c.WritePacket(data) From 419796ecc7f9d17cb9beb3002a4f3f23c7c8a342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Sun, 21 May 2023 14:46:14 -0600 Subject: [PATCH 3/3] Update based on review --- packet/conn.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/packet/conn.go b/packet/conn.go index e77718a24..963b89e98 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -263,9 +263,7 @@ func (c *Conn) WritePacket(data []byte) error { } else if n != len(data) { return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) } - case MYSQL_COMPRESS_ZLIB: - fallthrough - case MYSQL_COMPRESS_ZSTD: + case MYSQL_COMPRESS_ZLIB, MYSQL_COMPRESS_ZSTD: if n, err := c.writeCompressed(data); err != nil { return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) } else if n != len(data) {