Skip to content

Commit 82c1282

Browse files
committed
Merge branch 'master' into PARTIAL_UPDATE_ROWS_EVENT
# Conflicts: # .github/workflows/ci.yml
2 parents 1ca1b07 + 0e9bf02 commit 82c1282

File tree

4 files changed

+186
-27
lines changed

4 files changed

+186
-27
lines changed

.github/workflows/ci.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
strategy:
77
matrix:
88
go: [ 1.19, 1.18, 1.17, 1.16 ]
9-
os: [ ubuntu-20.04 ]
9+
os: [ ubuntu-22.04, ubuntu-20.04 ]
1010
name: Tests Go ${{ matrix.go }} # This name is used in main branch protection rules
1111
runs-on: ${{ matrix.os }}
1212

@@ -48,3 +48,4 @@ jobs:
4848
uses: golangci/golangci-lint-action@v2
4949
with:
5050
version: latest
51+
args: --timeout=3m

client/auth.go

+42-22
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func authPluginAllowed(pluginName string) bool {
2626
return false
2727
}
2828

29-
// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
29+
// See: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
3030
func (c *Conn) readInitialHandshake() error {
3131
data, err := c.ReadPacket()
3232
if err != nil {
@@ -40,24 +40,28 @@ func (c *Conn) readInitialHandshake() error {
4040
if data[0] < MinProtocolVersion {
4141
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
4242
}
43+
pos := 1
4344

4445
// skip mysql version
4546
// mysql version end with 0x00
46-
version := data[1 : bytes.IndexByte(data[1:], 0x00)+1]
47+
version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1]
4748
c.serverVersion = string(version)
48-
pos := 1 + len(version)
49+
pos += len(version) + 1 /*trailing zero byte*/
4950

5051
// connection id length is 4
5152
c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
5253
pos += 4
5354

54-
c.salt = []byte{}
55-
c.salt = append(c.salt, data[pos:pos+8]...)
55+
// first 8 bytes of the plugin provided data (scramble)
56+
c.salt = append(c.salt[:0], data[pos:pos+8]...)
57+
pos += 8
5658

57-
// skip filter
58-
pos += 8 + 1
59+
if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble
60+
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
61+
}
62+
pos++
5963

60-
// capability lower 2 bytes
64+
// The lower 2 bytes of the Capabilities Flags
6165
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
6266
// check protocol
6367
if c.capability&CLIENT_PROTOCOL_41 == 0 {
@@ -69,35 +73,51 @@ func (c *Conn) readInitialHandshake() error {
6973
pos += 2
7074

7175
if len(data) > pos {
72-
// skip server charset
76+
// default server a_protocol_character_set, only the lower 8-bits
7377
// c.charset = data[pos]
7478
pos += 1
7579

7680
c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
7781
pos += 2
78-
// capability flags (upper 2 bytes)
82+
83+
// The upper 2 bytes of the Capabilities Flags
7984
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
8085
pos += 2
8186

82-
// auth_data is end with 0x00, min data length is 13 + 8 = 21
83-
// ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
84-
maxAuthDataLen := 21
85-
if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen {
86-
maxAuthDataLen = int(data[pos])
87+
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
88+
authPluginDataLen := data[pos]
89+
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
90+
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
8791
}
92+
pos++
8893

8994
// skip reserved (all [00])
90-
pos += 10 + 1
95+
pos += 6
9196

92-
// auth_data is end with 0x00, so we need to trim 0x00
93-
resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1
94-
c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...)
97+
// https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift
98+
if c.capability&CLIENT_LONG_PASSWORD != 0 {
99+
// skip reserved (all [00])
100+
pos += 4
101+
} else {
102+
// unknown
103+
pos += 4
104+
}
105+
106+
if rest := int(authPluginDataLen) - 8; rest > 0 {
107+
authPluginDataPart2 := data[pos : pos+rest]
108+
pos += rest
95109

96-
// skip reset of end pos
97-
pos = resetOfAuthDataEndPos + 1
110+
c.salt = append(c.salt, authPluginDataPart2...)
111+
}
98112

99113
if c.capability&CLIENT_PLUGIN_AUTH != 0 {
100-
c.authPluginName = string(data[pos : len(data)-1])
114+
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
115+
pos += len(c.authPluginName)
116+
117+
if data[pos] != 0 {
118+
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
119+
}
120+
pos++
101121
}
102122
}
103123

client/conn.go

+129-3
Original file line numberDiff line numberDiff line change
@@ -119,18 +119,18 @@ func (c *Conn) handshake() error {
119119
var err error
120120
if err = c.readInitialHandshake(); err != nil {
121121
c.Close()
122-
return errors.Trace(err)
122+
return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err))
123123
}
124124

125125
if err := c.writeAuthHandshake(); err != nil {
126126
c.Close()
127127

128-
return errors.Trace(err)
128+
return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err))
129129
}
130130

131131
if err := c.handleAuthResult(); err != nil {
132132
c.Close()
133-
return errors.Trace(err)
133+
return errors.Trace(fmt.Errorf("handleAuthResult: %w", err))
134134
}
135135

136136
return nil
@@ -408,3 +408,129 @@ func (c *Conn) exec(query string) (*Result, error) {
408408

409409
return c.readResult(false)
410410
}
411+
412+
func (c *Conn) CapabilityString() string {
413+
var caps []string
414+
capability := c.capability
415+
for i := 0; capability != 0; i++ {
416+
field := uint32(1 << i)
417+
if capability&field == 0 {
418+
continue
419+
}
420+
capability ^= field
421+
422+
switch field {
423+
case CLIENT_LONG_PASSWORD:
424+
caps = append(caps, "CLIENT_LONG_PASSWORD")
425+
case CLIENT_FOUND_ROWS:
426+
caps = append(caps, "CLIENT_FOUND_ROWS")
427+
case CLIENT_LONG_FLAG:
428+
caps = append(caps, "CLIENT_LONG_FLAG")
429+
case CLIENT_CONNECT_WITH_DB:
430+
caps = append(caps, "CLIENT_CONNECT_WITH_DB")
431+
case CLIENT_NO_SCHEMA:
432+
caps = append(caps, "CLIENT_NO_SCHEMA")
433+
case CLIENT_COMPRESS:
434+
caps = append(caps, "CLIENT_COMPRESS")
435+
case CLIENT_ODBC:
436+
caps = append(caps, "CLIENT_ODBC")
437+
case CLIENT_LOCAL_FILES:
438+
caps = append(caps, "CLIENT_LOCAL_FILES")
439+
case CLIENT_IGNORE_SPACE:
440+
caps = append(caps, "CLIENT_IGNORE_SPACE")
441+
case CLIENT_PROTOCOL_41:
442+
caps = append(caps, "CLIENT_PROTOCOL_41")
443+
case CLIENT_INTERACTIVE:
444+
caps = append(caps, "CLIENT_INTERACTIVE")
445+
case CLIENT_SSL:
446+
caps = append(caps, "CLIENT_SSL")
447+
case CLIENT_IGNORE_SIGPIPE:
448+
caps = append(caps, "CLIENT_IGNORE_SIGPIPE")
449+
case CLIENT_TRANSACTIONS:
450+
caps = append(caps, "CLIENT_TRANSACTIONS")
451+
case CLIENT_RESERVED:
452+
caps = append(caps, "CLIENT_RESERVED")
453+
case CLIENT_SECURE_CONNECTION:
454+
caps = append(caps, "CLIENT_SECURE_CONNECTION")
455+
case CLIENT_MULTI_STATEMENTS:
456+
caps = append(caps, "CLIENT_MULTI_STATEMENTS")
457+
case CLIENT_MULTI_RESULTS:
458+
caps = append(caps, "CLIENT_MULTI_RESULTS")
459+
case CLIENT_PS_MULTI_RESULTS:
460+
caps = append(caps, "CLIENT_PS_MULTI_RESULTS")
461+
case CLIENT_PLUGIN_AUTH:
462+
caps = append(caps, "CLIENT_PLUGIN_AUTH")
463+
case CLIENT_CONNECT_ATTRS:
464+
caps = append(caps, "CLIENT_CONNECT_ATTRS")
465+
case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA:
466+
caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA")
467+
case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS:
468+
caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS")
469+
case CLIENT_SESSION_TRACK:
470+
caps = append(caps, "CLIENT_SESSION_TRACK")
471+
case CLIENT_DEPRECATE_EOF:
472+
caps = append(caps, "CLIENT_DEPRECATE_EOF")
473+
case CLIENT_OPTIONAL_RESULTSET_METADATA:
474+
caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA")
475+
case CLIENT_ZSTD_COMPRESSION_ALGORITHM:
476+
caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM")
477+
case CLIENT_QUERY_ATTRIBUTES:
478+
caps = append(caps, "CLIENT_QUERY_ATTRIBUTES")
479+
case MULTI_FACTOR_AUTHENTICATION:
480+
caps = append(caps, "MULTI_FACTOR_AUTHENTICATION")
481+
case CLIENT_CAPABILITY_EXTENSION:
482+
caps = append(caps, "CLIENT_CAPABILITY_EXTENSION")
483+
case CLIENT_SSL_VERIFY_SERVER_CERT:
484+
caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT")
485+
case CLIENT_REMEMBER_OPTIONS:
486+
caps = append(caps, "CLIENT_REMEMBER_OPTIONS")
487+
default:
488+
caps = append(caps, fmt.Sprintf("(%d)", field))
489+
}
490+
}
491+
492+
return strings.Join(caps, "|")
493+
}
494+
495+
func (c *Conn) StatusString() string {
496+
var stats []string
497+
status := c.status
498+
for i := 0; status != 0; i++ {
499+
field := uint16(1 << i)
500+
if status&field == 0 {
501+
continue
502+
}
503+
status ^= field
504+
505+
switch field {
506+
case SERVER_STATUS_IN_TRANS:
507+
stats = append(stats, "SERVER_STATUS_IN_TRANS")
508+
case SERVER_STATUS_AUTOCOMMIT:
509+
stats = append(stats, "SERVER_STATUS_AUTOCOMMIT")
510+
case SERVER_MORE_RESULTS_EXISTS:
511+
stats = append(stats, "SERVER_MORE_RESULTS_EXISTS")
512+
case SERVER_STATUS_NO_GOOD_INDEX_USED:
513+
stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED")
514+
case SERVER_STATUS_NO_INDEX_USED:
515+
stats = append(stats, "SERVER_STATUS_NO_INDEX_USED")
516+
case SERVER_STATUS_CURSOR_EXISTS:
517+
stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS")
518+
case SERVER_STATUS_LAST_ROW_SEND:
519+
stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND")
520+
case SERVER_STATUS_DB_DROPPED:
521+
stats = append(stats, "SERVER_STATUS_DB_DROPPED")
522+
case SERVER_STATUS_NO_BACKSLASH_ESCAPED:
523+
stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED")
524+
case SERVER_STATUS_METADATA_CHANGED:
525+
stats = append(stats, "SERVER_STATUS_METADATA_CHANGED")
526+
case SERVER_QUERY_WAS_SLOW:
527+
stats = append(stats, "SERVER_QUERY_WAS_SLOW")
528+
case SERVER_PS_OUT_PARAMS:
529+
stats = append(stats, "SERVER_PS_OUT_PARAMS")
530+
default:
531+
stats = append(stats, fmt.Sprintf("(%d)", field))
532+
}
533+
}
534+
535+
return strings.Join(stats, "|")
536+
}

mysql/const.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ const (
7676
)
7777

7878
const (
79+
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html
80+
7981
CLIENT_LONG_PASSWORD uint32 = 1 << iota
8082
CLIENT_FOUND_ROWS
8183
CLIENT_LONG_FLAG
@@ -98,6 +100,16 @@ const (
98100
CLIENT_PLUGIN_AUTH
99101
CLIENT_CONNECT_ATTRS
100102
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
103+
CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS
104+
CLIENT_SESSION_TRACK
105+
CLIENT_DEPRECATE_EOF
106+
CLIENT_OPTIONAL_RESULTSET_METADATA
107+
CLIENT_ZSTD_COMPRESSION_ALGORITHM
108+
CLIENT_QUERY_ATTRIBUTES
109+
MULTI_FACTOR_AUTHENTICATION
110+
CLIENT_CAPABILITY_EXTENSION
111+
CLIENT_SSL_VERIFY_SERVER_CERT
112+
CLIENT_REMEMBER_OPTIONS
101113
)
102114

103115
const (
@@ -119,7 +131,7 @@ const (
119131
MYSQL_TYPE_VARCHAR
120132
MYSQL_TYPE_BIT
121133

122-
//mysql 5.6
134+
// mysql 5.6
123135
MYSQL_TYPE_TIMESTAMP2
124136
MYSQL_TYPE_DATETIME2
125137
MYSQL_TYPE_TIME2

0 commit comments

Comments
 (0)