-
Notifications
You must be signed in to change notification settings - Fork 1k
Support MySQL Compressed Protocol #787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ package packet | |
import ( | ||
"bufio" | ||
"bytes" | ||
"compress/zlib" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/sha1" | ||
|
@@ -12,6 +13,7 @@ import ( | |
"net" | ||
"sync" | ||
|
||
"github.com/DataDog/zstd" | ||
. "github.com/go-mysql-org/go-mysql/mysql" | ||
"github.com/go-mysql-org/go-mysql/utils" | ||
"github.com/pingcap/errors" | ||
|
@@ -56,6 +58,16 @@ type Conn struct { | |
header [4]byte | ||
|
||
Sequence uint8 | ||
|
||
Compression uint8 | ||
|
||
CompressedSequence uint8 | ||
|
||
compressedHeader [7]byte | ||
|
||
compressedReaderActive bool | ||
|
||
compressedReader io.Reader | ||
} | ||
|
||
func NewConn(conn net.Conn) *Conn { | ||
|
@@ -94,8 +106,43 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { | |
utils.BytesBufferPut(buf) | ||
}() | ||
|
||
if err := c.ReadPacketTo(buf); err != nil { | ||
return nil, errors.Trace(err) | ||
if c.Compression != MYSQL_COMPRESS_NONE { | ||
if !c.compressedReaderActive { | ||
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { | ||
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) | ||
} | ||
|
||
compressedSequence := c.compressedHeader[3] | ||
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16) | ||
if compressedSequence != c.CompressedSequence { | ||
return nil, errors.Errorf("invalid compressed sequence %d != %d", | ||
compressedSequence, c.CompressedSequence) | ||
} | ||
|
||
if uncompressedLength > 0 { | ||
var err error | ||
switch c.Compression { | ||
case MYSQL_COMPRESS_ZLIB: | ||
c.compressedReader, err = zlib.NewReader(c.reader) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we unify the |
||
case MYSQL_COMPRESS_ZSTD: | ||
c.compressedReader = zstd.NewReader(c.reader) | ||
} | ||
if err != nil { | ||
return nil, err | ||
} | ||
} | ||
c.compressedReaderActive = true | ||
} | ||
} | ||
|
||
if c.compressedReader != nil { | ||
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
} else { | ||
if err := c.ReadPacketTo(buf, c.reader); err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
} | ||
|
||
readBytes := buf.Bytes() | ||
|
@@ -145,8 +192,8 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err | |
return written, nil | ||
} | ||
|
||
func (c *Conn) ReadPacketTo(w io.Writer) error { | ||
if _, err := io.ReadFull(c.reader, c.header[:4]); err != nil { | ||
func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error { | ||
if _, err := io.ReadFull(r, c.header[:4]); err != nil { | ||
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err) | ||
} | ||
|
||
|
@@ -164,7 +211,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { | |
buf.Grow(length) | ||
} | ||
|
||
if n, err := c.copyN(w, c.reader, int64(length)); err != nil { | ||
if n, err := c.copyN(w, r, int64(length)); err != nil { | ||
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) | ||
} else if n != int64(length) { | ||
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) | ||
|
@@ -173,7 +220,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { | |
return nil | ||
} | ||
|
||
if err := c.ReadPacketTo(w); err != nil { | ||
if err = c.ReadPacketTo(w, r); err != nil { | ||
return errors.Wrap(err, "ReadPacketTo failed") | ||
} | ||
} | ||
|
@@ -209,14 +256,95 @@ func (c *Conn) WritePacket(data []byte) error { | |
data[2] = byte(length >> 16) | ||
data[3] = c.Sequence | ||
|
||
if n, err := c.Write(data); err != nil { | ||
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) | ||
} else if n != len(data) { | ||
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) | ||
switch c.Compression { | ||
case MYSQL_COMPRESS_NONE: | ||
if n, err := c.Write(data); err != nil { | ||
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) | ||
} else if n != len(data) { | ||
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) | ||
} | ||
case MYSQL_COMPRESS_ZLIB: | ||
fallthrough | ||
case MYSQL_COMPRESS_ZSTD: | ||
dveeden marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if n, err := c.writeCompressed(data); err != nil { | ||
return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) | ||
} else if n != len(data) { | ||
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) | ||
} | ||
c.compressedReader = nil | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to set it to nil? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On line 138: if c.compressedReader != nil {
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
return nil, errors.Trace(err)
}
} else { I could move the if uncompressedLength > 0 {
var err error
switch c.Compression {
case MYSQL_COMPRESS_ZLIB:
c.compressedReader, err = zlib.NewReader(c.reader)
case MYSQL_COMPRESS_ZSTD:
c.compressedReader = zstd.NewReader(c.reader)
}
if err != nil {
return nil, err
}
} The payload of a compressed packet is not always actually compressed. Payloads of less than 50 bytes are often uncompressed as the compression overhead isn't worth it. If the |
||
c.compressedReaderActive = false | ||
default: | ||
return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set") | ||
} | ||
|
||
c.Sequence++ | ||
return nil | ||
} | ||
|
||
func (c *Conn) writeCompressed(data []byte) (n int, err error) { | ||
var compressedLength, uncompressedLength int | ||
var payload, compressedPacket bytes.Buffer | ||
var w io.WriteCloser | ||
minCompressLength := 50 | ||
compressedHeader := make([]byte, 7) | ||
|
||
switch c.Compression { | ||
case MYSQL_COMPRESS_ZLIB: | ||
w, err = zlib.NewWriterLevel(&payload, zlib.HuffmanOnly) | ||
case MYSQL_COMPRESS_ZSTD: | ||
w = zstd.NewWriter(&payload) | ||
} | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
if len(data) > minCompressLength { | ||
uncompressedLength = len(data) | ||
n, err = w.Write(data) | ||
if err != nil { | ||
return 0, err | ||
} | ||
err = w.Close() | ||
if err != nil { | ||
return 0, err | ||
} | ||
} | ||
|
||
if len(data) > minCompressLength { | ||
compressedLength = len(payload.Bytes()) | ||
} else { | ||
compressedLength = len(data) | ||
} | ||
|
||
c.CompressedSequence = 0 | ||
compressedHeader[0] = byte(compressedLength) | ||
compressedHeader[1] = byte(compressedLength >> 8) | ||
compressedHeader[2] = byte(compressedLength >> 16) | ||
compressedHeader[3] = c.CompressedSequence | ||
compressedHeader[4] = byte(uncompressedLength) | ||
compressedHeader[5] = byte(uncompressedLength >> 8) | ||
compressedHeader[6] = byte(uncompressedLength >> 16) | ||
_, err = compressedPacket.Write(compressedHeader) | ||
if err != nil { | ||
return 0, err | ||
} | ||
c.CompressedSequence++ | ||
|
||
if len(data) > minCompressLength { | ||
_, err = compressedPacket.Write(payload.Bytes()) | ||
} else { | ||
c.Sequence++ | ||
return nil | ||
n, err = compressedPacket.Write(data) | ||
} | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
_, err = c.Write(compressedPacket.Bytes()) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
return n, nil | ||
} | ||
|
||
// WriteClearAuthPacket: Client clear text authentication packet | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should notify the user about their priority at the comment of
SetCapability