9
9
"crypto/sha1"
10
10
"crypto/x509"
11
11
"encoding/pem"
12
+ goErrors "errors"
12
13
"io"
13
14
"net"
14
15
"sync"
@@ -65,8 +66,6 @@ type Conn struct {
65
66
66
67
compressedHeader [7 ]byte
67
68
68
- compressedReaderActive bool
69
-
70
69
compressedReader io.Reader
71
70
}
72
71
@@ -107,42 +106,17 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
107
106
}()
108
107
109
108
if c .Compression != MYSQL_COMPRESS_NONE {
110
- if ! c .compressedReaderActive {
111
- if _ , err := io .ReadFull (c .reader , c .compressedHeader [:7 ]); err != nil {
112
- return nil , errors .Wrapf (ErrBadConn , "io.ReadFull(compressedHeader) failed. err %v" , err )
113
- }
114
-
115
- compressedSequence := c .compressedHeader [3 ]
116
- uncompressedLength := int (uint32 (c .compressedHeader [4 ]) | uint32 (c .compressedHeader [5 ])<< 8 | uint32 (c .compressedHeader [6 ])<< 16 )
117
- if compressedSequence != c .CompressedSequence {
118
- return nil , errors .Errorf ("invalid compressed sequence %d != %d" ,
119
- compressedSequence , c .CompressedSequence )
120
- }
121
-
122
- if uncompressedLength > 0 {
123
- var err error
124
- switch c .Compression {
125
- case MYSQL_COMPRESS_ZLIB :
126
- c .compressedReader , err = zlib .NewReader (c .reader )
127
- case MYSQL_COMPRESS_ZSTD :
128
- c .compressedReader , err = zstd .NewReader (c .reader )
129
- }
130
- if err != nil {
131
- return nil , err
132
- }
109
+ if c .compressedReader == nil {
110
+ var err error
111
+ c .compressedReader , err = c .newCompressedPacketReader ()
112
+ if err != nil {
113
+ return nil , err
133
114
}
134
- c .compressedReaderActive = true
135
115
}
136
116
}
137
117
138
- if c .compressedReader != nil {
139
- if err := c .ReadPacketTo (buf , c .compressedReader ); err != nil {
140
- return nil , errors .Trace (err )
141
- }
142
- } else {
143
- if err := c .ReadPacketTo (buf , c .reader ); err != nil {
144
- return nil , errors .Trace (err )
145
- }
118
+ if err := c .ReadPacketTo (buf ); err != nil {
119
+ return nil , errors .Trace (err )
146
120
}
147
121
148
122
readBytes := buf .Bytes ()
@@ -167,22 +141,78 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) {
167
141
return result , nil
168
142
}
169
143
170
- func (c * Conn ) copyN (dst io.Writer , src io.Reader , n int64 ) (written int64 , err error ) {
144
+ // newCompressedPacketReader creates a new compressed packet reader.
145
+ func (c * Conn ) newCompressedPacketReader () (io.Reader , error ) {
146
+ if _ , err := io .ReadFull (c .reader , c .compressedHeader [:7 ]); err != nil {
147
+ return nil , errors .Wrapf (ErrBadConn , "io.ReadFull(compressedHeader) failed. err %v" , err )
148
+ }
149
+
150
+ compressedSequence := c .compressedHeader [3 ]
151
+ if compressedSequence != c .CompressedSequence {
152
+ return nil , errors .Errorf ("invalid compressed sequence %d != %d" ,
153
+ compressedSequence , c .CompressedSequence )
154
+ }
155
+
156
+ compressedLength := int (uint32 (c .compressedHeader [0 ]) | uint32 (c .compressedHeader [1 ])<< 8 | uint32 (c .compressedHeader [2 ])<< 16 )
157
+ uncompressedLength := int (uint32 (c .compressedHeader [4 ]) | uint32 (c .compressedHeader [5 ])<< 8 | uint32 (c .compressedHeader [6 ])<< 16 )
158
+ if uncompressedLength > 0 {
159
+ limitedReader := io .LimitReader (c .reader , int64 (compressedLength ))
160
+ switch c .Compression {
161
+ case MYSQL_COMPRESS_ZLIB :
162
+ return zlib .NewReader (limitedReader )
163
+ case MYSQL_COMPRESS_ZSTD :
164
+ return zstd .NewReader (limitedReader )
165
+ }
166
+ }
167
+
168
+ return nil , nil
169
+ }
170
+
171
+ func (c * Conn ) currentPacketReader () io.Reader {
172
+ if c .Compression == MYSQL_COMPRESS_NONE || c .compressedReader == nil {
173
+ return c .reader
174
+ } else {
175
+ return c .compressedReader
176
+ }
177
+ }
178
+
179
+ func (c * Conn ) copyN (dst io.Writer , n int64 ) (int64 , error ) {
180
+ var written int64
181
+
171
182
for n > 0 {
172
183
bcap := cap (c .copyNBuf )
173
184
if int64 (bcap ) > n {
174
185
bcap = int (n )
175
186
}
176
187
buf := c .copyNBuf [:bcap ]
177
188
178
- rd , err := io .ReadAtLeast (src , buf , bcap )
189
+ // Call ReadAtLeast with the currentPacketReader as it may change on every iteration
190
+ // of this loop.
191
+ rd , err := io .ReadAtLeast (c .currentPacketReader (), buf , bcap )
192
+
179
193
n -= int64 (rd )
180
194
195
+ // ReadAtLeast will return EOF or ErrUnexpectedEOF when fewer than the min
196
+ // bytes are read. In this case, and when we have compression then advance
197
+ // the sequence number and reset the compressed reader to continue reading
198
+ // the remaining bytes in the next compressed packet.
199
+ if c .Compression != MYSQL_COMPRESS_NONE &&
200
+ (goErrors .Is (err , io .ErrUnexpectedEOF ) || goErrors .Is (err , io .EOF )) {
201
+ // we have read to EOF and read an incomplete uncompressed packet
202
+ // so advance the compressed sequence number and reset the compressed reader
203
+ // to get the remaining unread uncompressed bytes from the next compressed packet.
204
+ c .CompressedSequence ++
205
+ if c .compressedReader , err = c .newCompressedPacketReader (); err != nil {
206
+ return written , errors .Trace (err )
207
+ }
208
+ }
209
+
181
210
if err != nil {
182
211
return written , errors .Trace (err )
183
212
}
184
213
185
- wr , err := dst .Write (buf )
214
+ // careful to only write from the buffer the number of bytes read
215
+ wr , err := dst .Write (buf [:rd ])
186
216
written += int64 (wr )
187
217
if err != nil {
188
218
return written , errors .Trace (err )
@@ -192,9 +222,21 @@ func (c *Conn) copyN(dst io.Writer, src io.Reader, n int64) (written int64, err
192
222
return written , nil
193
223
}
194
224
195
- func (c * Conn ) ReadPacketTo (w io.Writer , r io.Reader ) error {
196
- if _ , err := io .ReadFull (r , c .header [:4 ]); err != nil {
225
+ func (c * Conn ) ReadPacketTo (w io.Writer ) error {
226
+ b := utils .BytesBufferGet ()
227
+ defer func () {
228
+ utils .BytesBufferPut (b )
229
+ }()
230
+
231
+ // packets that come in a compressed packet may be partial
232
+ // so use the copyN function to read the packet header into a
233
+ // buffer, since copyN is capable of getting the next compressed
234
+ // packet and updating the Conn state with a new compressedReader.
235
+ if _ , err := c .copyN (b , 4 ); err != nil {
197
236
return errors .Wrapf (ErrBadConn , "io.ReadFull(header) failed. err %v" , err )
237
+ } else {
238
+ // copy was successful so copy the 4 bytes from the buffer to the header
239
+ copy (c .header [:4 ], b .Bytes ()[:4 ])
198
240
}
199
241
200
242
length := int (uint32 (c .header [0 ]) | uint32 (c .header [1 ])<< 8 | uint32 (c .header [2 ])<< 16 )
@@ -211,7 +253,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
211
253
buf .Grow (length )
212
254
}
213
255
214
- if n , err := c .copyN (w , r , int64 (length )); err != nil {
256
+ if n , err := c .copyN (w , int64 (length )); err != nil {
215
257
return errors .Wrapf (ErrBadConn , "io.CopyN failed. err %v, copied %v, expected %v" , err , n , length )
216
258
} else if n != int64 (length ) {
217
259
return errors .Wrapf (ErrBadConn , "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected" , n , length )
@@ -220,7 +262,7 @@ func (c *Conn) ReadPacketTo(w io.Writer, r io.Reader) error {
220
262
return nil
221
263
}
222
264
223
- if err = c .ReadPacketTo (w , r ); err != nil {
265
+ if err = c .ReadPacketTo (w ); err != nil {
224
266
return errors .Wrap (err , "ReadPacketTo failed" )
225
267
}
226
268
}
@@ -270,7 +312,6 @@ func (c *Conn) WritePacket(data []byte) error {
270
312
return errors .Wrapf (ErrBadConn , "Write failed. only %v bytes written, while %v expected" , n , len (data ))
271
313
}
272
314
c .compressedReader = nil
273
- c .compressedReaderActive = false
274
315
default :
275
316
return errors .Wrapf (ErrBadConn , "Write failed. Unsuppored compression algorithm set" )
276
317
}
0 commit comments