Skip to content

Commit 007f306

Browse files
dvilaverdedvilaverdelance6716
authored
fixing bad connection error when reading large compressed packets (#863)
* fixing bad connection error when reading large compressed packets * fixing linting errors * minor cleanup and some more comments * minor cleanup and some more comments * fixing issue when net_buffer_length=1024 * fixing packet reader lookup condition * handle possible nil access violation when attempting to read next compressed packet * removed deprecated linters that no longer exist in golangci-lint 1.58.0 * addressing PR feedback * addressing PR feedback * removed compressedReaderActive --------- Co-authored-by: dvilaverde <[email protected]> Co-authored-by: lance6716 <[email protected]>
1 parent 0ad0d03 commit 007f306

File tree

1 file changed

+83
-42
lines changed

1 file changed

+83
-42
lines changed

packet/conn.go

+83-42
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"crypto/sha1"
1010
"crypto/x509"
1111
"encoding/pem"
12+
goErrors "errors"
1213
"io"
1314
"net"
1415
"sync"
@@ -65,8 +66,6 @@ type Conn struct {
6566

6667
compressedHeader [7]byte
6768

68-
compressedReaderActive bool
69-
7069
compressedReader io.Reader
7170
}
7271

@@ -107,42 +106,17 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
107106
}()
108107

109108
if c.Compression != MYSQL_COMPRESS_NONE {
110-
if !c.compressedReaderActive {
111-
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
112-
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
113-
}
114-
115-
compressedSequence := c.compressedHeader[3]
116-
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
117-
if compressedSequence != c.CompressedSequence {
118-
return nil, errors.Errorf("invalid compressed sequence %d != %d",
119-
compressedSequence, c.CompressedSequence)
120-
}
121-
122-
if uncompressedLength > 0 {
123-
var err error
124-
switch c.Compression {
125-
case MYSQL_COMPRESS_ZLIB:
126-
c.compressedReader, err = zlib.NewReader(c.reader)
127-
case MYSQL_COMPRESS_ZSTD:
128-
c.compressedReader, err = zstd.NewReader(c.reader)
129-
}
130-
if err != nil {
131-
return nil, err
132-
}
109+
if c.compressedReader == nil {
110+
var err error
111+
c.compressedReader, err = c.newCompressedPacketReader()
112+
if err != nil {
113+
return nil, err
133114
}
134-
c.compressedReaderActive = true
135115
}
136116
}
137117

138-
if c.compressedReader != nil {
139-
if err := c.ReadPacketTo(buf, c.compressedReader); err != nil {
140-
return nil, errors.Trace(err)
141-
}
142-
} else {
143-
if err := c.ReadPacketTo(buf, c.reader); err != nil {
144-
return nil, errors.Trace(err)
145-
}
118+
if err := c.ReadPacketTo(buf); err != nil {
119+
return nil, errors.Trace(err)
146120
}
147121

148122
readBytes := buf.Bytes()
@@ -167,22 +141,78 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
167141
return result, nil
168142
}
169143

170-
func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err error) {
144+
// newCompressedPacketReader creates a new compressed packet reader.
145+
func (c *Conn) newCompressedPacketReader() (io.Reader, error) {
146+
if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil {
147+
return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err)
148+
}
149+
150+
compressedSequence := c.compressedHeader[3]
151+
if compressedSequence != c.CompressedSequence {
152+
return nil, errors.Errorf("invalid compressed sequence %d != %d",
153+
compressedSequence, c.CompressedSequence)
154+
}
155+
156+
compressedLength := int(uint32(c.compressedHeader[0]) | uint32(c.compressedHeader[1])<<8 | uint32(c.compressedHeader[2])<<16)
157+
uncompressedLength := int(uint32(c.compressedHeader[4]) | uint32(c.compressedHeader[5])<<8 | uint32(c.compressedHeader[6])<<16)
158+
if uncompressedLength > 0 {
159+
limitedReader := io.LimitReader(c.reader, int64(compressedLength))
160+
switch c.Compression {
161+
case MYSQL_COMPRESS_ZLIB:
162+
return zlib.NewReader(limitedReader)
163+
case MYSQL_COMPRESS_ZSTD:
164+
return zstd.NewReader(limitedReader)
165+
}
166+
}
167+
168+
return nil, nil
169+
}
170+
171+
func (c *Conn) currentPacketReader() io.Reader {
172+
if c.Compression == MYSQL_COMPRESS_NONE || c.compressedReader == nil {
173+
return c.reader
174+
} else {
175+
return c.compressedReader
176+
}
177+
}
178+
179+
func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) {
180+
var written int64
181+
171182
for n > 0 {
172183
bcap := cap(c.copyNBuf)
173184
if int64(bcap) > n {
174185
bcap = int(n)
175186
}
176187
buf := c.copyNBuf[:bcap]
177188

178-
rd, err := io.ReadAtLeast(src, buf, bcap)
189+
// Call ReadAtLeast with the currentPacketReader as it may change on every iteration
190+
// of this loop.
191+
rd, err := io.ReadAtLeast(c.currentPacketReader(), buf, bcap)
192+
179193
n -= int64(rd)
180194

195+
// ReadAtLeast will return EOF or ErrUnexpectedEOF when fewer than the min
196+
// bytes are read. In this case, and when we have compression then advance
197+
// the sequence number and reset the compressed reader to continue reading
198+
// the remaining bytes in the next compressed packet.
199+
if c.Compression != MYSQL_COMPRESS_NONE &&
200+
(goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) {
201+
// we have read to EOF and read an incomplete uncompressed packet
202+
// so advance the compressed sequence number and reset the compressed reader
203+
// to get the remaining unread uncompressed bytes from the next compressed packet.
204+
c.CompressedSequence++
205+
if c.compressedReader, err = c.newCompressedPacketReader(); err != nil {
206+
return written, errors.Trace(err)
207+
}
208+
}
209+
181210
if err != nil {
182211
return written, errors.Trace(err)
183212
}
184213

185-
wr, err := dst.Write(buf)
214+
// careful to only write from the buffer the number of bytes read
215+
wr, err := dst.Write(buf[:rd])
186216
written += int64(wr)
187217
if err != nil {
188218
return written, errors.Trace(err)
@@ -192,9 +222,21 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
192222
return written, nil
193223
}
194224

195-
func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
196-
if _, err := io.ReadFull(r, c.header[:4]); err != nil {
225+
func (c *Conn) ReadPacketTo(w io.Writer) error {
226+
b := utils.BytesBufferGet()
227+
defer func() {
228+
utils.BytesBufferPut(b)
229+
}()
230+
231+
// packets that come in a compressed packet may be partial
232+
// so use the copyN function to read the packet header into a
233+
// buffer, since copyN is capable of getting the next compressed
234+
// packet and updating the Conn state with a new compressedReader.
235+
if _, err := c.copyN(b, 4); err != nil {
197236
return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err)
237+
} else {
238+
// copy was successful so copy the 4 bytes from the buffer to the header
239+
copy(c.header[:4], b.Bytes()[:4])
198240
}
199241

200242
length := int(uint32(c.header[0]) | uint32(c.header[1])<<8 | uint32(c.header[2])<<16)
@@ -211,7 +253,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
211253
buf.Grow(length)
212254
}
213255

214-
if n, err := c.copyN(w, r, int64(length)); err != nil {
256+
if n, err := c.copyN(w, int64(length)); err != nil {
215257
return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length)
216258
} else if n != int64(length) {
217259
return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length)
@@ -220,7 +262,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
220262
return nil
221263
}
222264

223-
if err = c.ReadPacketTo(w, r); err != nil {
265+
if err = c.ReadPacketTo(w); err != nil {
224266
return errors.Wrap(err, "ReadPacketTo failed")
225267
}
226268
}
@@ -270,7 +312,6 @@ func (c *Conn) WritePacket(data []byte) error {
270312
return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data))
271313
}
272314
c.compressedReader = nil
273-
c.compressedReaderActive = false
274315
default:
275316
return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set")
276317
}

0 commit comments

Comments
 (0)