Skip to content

Commit f854680

Browse files
authoredApr 12, 2023
Merge pull request #781 from go-mysql-org/mysql8-auth
mysql8 auth compatibility
2 parents e928039 + 7df3e00 commit f854680

File tree

3 files changed

+193
-27
lines changed

3 files changed

+193
-27
lines changed
 

‎client/auth.go

+51-23
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ func authPluginAllowed(pluginName string) bool {
2626
return false
2727
}
2828

29-
// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
29+
// See:
30+
// - https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
31+
// - https://github.com/alibaba/canal/blob/0ec46991499a22870dde4ae736b2586cbcbfea94/driver/src/main/java/com/alibaba/otter/canal/parse/driver/mysql/packets/server/HandshakeInitializationPacket.java#L89
32+
// - https://github.com/vapor/mysql-nio/blob/main/Sources/MySQLNIO/Protocol/MySQLProtocol%2BHandshakeV10.swift
33+
// - https://github.com/github/vitess-gh/blob/70ae1a2b3a116ff6411b0f40852d6e71382f6e07/go/mysql/client.go
3034
func (c *Conn) readInitialHandshake() error {
3135
data, err := c.ReadPacket()
3236
if err != nil {
@@ -40,24 +44,28 @@ func (c *Conn) readInitialHandshake() error {
4044
if data[0] < MinProtocolVersion {
4145
return errors.Errorf("invalid protocol version %d, must >= 10", data[0])
4246
}
47+
pos := 1
4348

4449
// skip mysql version
4550
// mysql version end with 0x00
46-
version := data[1 : bytes.IndexByte(data[1:], 0x00)+1]
51+
version := data[pos : bytes.IndexByte(data[pos:], 0x00)+1]
4752
c.serverVersion = string(version)
48-
pos := 1 + len(version)
53+
pos += len(version) + 1 /*trailing zero byte*/
4954

5055
// connection id length is 4
5156
c.connectionID = binary.LittleEndian.Uint32(data[pos : pos+4])
5257
pos += 4
5358

54-
c.salt = []byte{}
55-
c.salt = append(c.salt, data[pos:pos+8]...)
59+
// first 8 bytes of the plugin provided data (scramble)
60+
c.salt = append(c.salt[:0], data[pos:pos+8]...)
61+
pos += 8
5662

57-
// skip filter
58-
pos += 8 + 1
63+
if data[pos] != 0 { // 0x00 byte, terminating the first part of a scramble
64+
return errors.Errorf("expect 0x00 after scramble, got %q", rune(data[pos]))
65+
}
66+
pos++
5967

60-
// capability lower 2 bytes
68+
// The lower 2 bytes of the Capabilities Flags
6169
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
6270
// check protocol
6371
if c.capability&CLIENT_PROTOCOL_41 == 0 {
@@ -69,35 +77,55 @@ func (c *Conn) readInitialHandshake() error {
6977
pos += 2
7078

7179
if len(data) > pos {
72-
// skip server charset
80+
// default server a_protocol_character_set, only the lower 8-bits
7381
// c.charset = data[pos]
7482
pos += 1
7583

7684
c.status = binary.LittleEndian.Uint16(data[pos : pos+2])
7785
pos += 2
78-
// capability flags (upper 2 bytes)
86+
87+
// The upper 2 bytes of the Capabilities Flags
7988
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
8089
pos += 2
8190

82-
// auth_data is end with 0x00, min data length is 13 + 8 = 21
83-
// ref to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
84-
maxAuthDataLen := 21
85-
if c.capability&CLIENT_PLUGIN_AUTH != 0 && int(data[pos]) > maxAuthDataLen {
86-
maxAuthDataLen = int(data[pos])
91+
// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
92+
authPluginDataLen := data[pos]
93+
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
94+
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
8795
}
96+
pos++
8897

89-
// skip reserved (all [00])
90-
pos += 10 + 1
98+
// skip reserved (all [00] ?)
99+
pos += 10
91100

92-
// auth_data is end with 0x00, so we need to trim 0x00
93-
resetOfAuthDataEndPos := pos + maxAuthDataLen - 8 - 1
94-
c.salt = append(c.salt, data[pos:resetOfAuthDataEndPos]...)
101+
if c.capability&CLIENT_SECURE_CONNECTION != 0 {
102+
// Rest of the plugin provided data (scramble)
95103

96-
// skip reset of end pos
97-
pos = resetOfAuthDataEndPos + 1
104+
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
105+
// $len=MAX(13, length of auth-plugin-data - 8)
106+
//
107+
// https://github.com/mysql/mysql-server/blob/1bfe02bdad6604d54913c62614bde57a055c8332/sql/auth/sql_authentication.cc#L1641-L1642
108+
// the first packet *must* have at least 20 bytes of a scramble.
109+
// if a plugin provided less, we pad it to 20 with zeros
110+
rest := int(authPluginDataLen) - 8
111+
if max := 12 + 1; rest < max {
112+
rest = max
113+
}
114+
115+
authPluginDataPart2 := data[pos : pos+rest]
116+
pos += rest
117+
118+
c.salt = append(c.salt, authPluginDataPart2...)
119+
}
98120

99121
if c.capability&CLIENT_PLUGIN_AUTH != 0 {
100-
c.authPluginName = string(data[pos : len(data)-1])
122+
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
123+
pos += len(c.authPluginName)
124+
125+
if data[pos] != 0 {
126+
return errors.Errorf("expect 0x00 after authPluginName, got %q", rune(data[pos]))
127+
}
128+
// pos++ // ineffectual
101129
}
102130
}
103131

‎client/conn.go

+129-3
Original file line numberDiff line numberDiff line change
@@ -119,18 +119,18 @@ func (c *Conn) handshake() error {
119119
var err error
120120
if err = c.readInitialHandshake(); err != nil {
121121
c.Close()
122-
return errors.Trace(err)
122+
return errors.Trace(fmt.Errorf("readInitialHandshake: %w", err))
123123
}
124124

125125
if err := c.writeAuthHandshake(); err != nil {
126126
c.Close()
127127

128-
return errors.Trace(err)
128+
return errors.Trace(fmt.Errorf("writeAuthHandshake: %w", err))
129129
}
130130

131131
if err := c.handleAuthResult(); err != nil {
132132
c.Close()
133-
return errors.Trace(err)
133+
return errors.Trace(fmt.Errorf("handleAuthResult: %w", err))
134134
}
135135

136136
return nil
@@ -408,3 +408,129 @@ func (c *Conn) exec(query string) (*Result, error) {
408408

409409
return c.readResult(false)
410410
}
411+
412+
func (c *Conn) CapabilityString() string {
413+
var caps []string
414+
capability := c.capability
415+
for i := 0; capability != 0; i++ {
416+
field := uint32(1 << i)
417+
if capability&field == 0 {
418+
continue
419+
}
420+
capability ^= field
421+
422+
switch field {
423+
case CLIENT_LONG_PASSWORD:
424+
caps = append(caps, "CLIENT_LONG_PASSWORD")
425+
case CLIENT_FOUND_ROWS:
426+
caps = append(caps, "CLIENT_FOUND_ROWS")
427+
case CLIENT_LONG_FLAG:
428+
caps = append(caps, "CLIENT_LONG_FLAG")
429+
case CLIENT_CONNECT_WITH_DB:
430+
caps = append(caps, "CLIENT_CONNECT_WITH_DB")
431+
case CLIENT_NO_SCHEMA:
432+
caps = append(caps, "CLIENT_NO_SCHEMA")
433+
case CLIENT_COMPRESS:
434+
caps = append(caps, "CLIENT_COMPRESS")
435+
case CLIENT_ODBC:
436+
caps = append(caps, "CLIENT_ODBC")
437+
case CLIENT_LOCAL_FILES:
438+
caps = append(caps, "CLIENT_LOCAL_FILES")
439+
case CLIENT_IGNORE_SPACE:
440+
caps = append(caps, "CLIENT_IGNORE_SPACE")
441+
case CLIENT_PROTOCOL_41:
442+
caps = append(caps, "CLIENT_PROTOCOL_41")
443+
case CLIENT_INTERACTIVE:
444+
caps = append(caps, "CLIENT_INTERACTIVE")
445+
case CLIENT_SSL:
446+
caps = append(caps, "CLIENT_SSL")
447+
case CLIENT_IGNORE_SIGPIPE:
448+
caps = append(caps, "CLIENT_IGNORE_SIGPIPE")
449+
case CLIENT_TRANSACTIONS:
450+
caps = append(caps, "CLIENT_TRANSACTIONS")
451+
case CLIENT_RESERVED:
452+
caps = append(caps, "CLIENT_RESERVED")
453+
case CLIENT_SECURE_CONNECTION:
454+
caps = append(caps, "CLIENT_SECURE_CONNECTION")
455+
case CLIENT_MULTI_STATEMENTS:
456+
caps = append(caps, "CLIENT_MULTI_STATEMENTS")
457+
case CLIENT_MULTI_RESULTS:
458+
caps = append(caps, "CLIENT_MULTI_RESULTS")
459+
case CLIENT_PS_MULTI_RESULTS:
460+
caps = append(caps, "CLIENT_PS_MULTI_RESULTS")
461+
case CLIENT_PLUGIN_AUTH:
462+
caps = append(caps, "CLIENT_PLUGIN_AUTH")
463+
case CLIENT_CONNECT_ATTRS:
464+
caps = append(caps, "CLIENT_CONNECT_ATTRS")
465+
case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA:
466+
caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA")
467+
case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS:
468+
caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS")
469+
case CLIENT_SESSION_TRACK:
470+
caps = append(caps, "CLIENT_SESSION_TRACK")
471+
case CLIENT_DEPRECATE_EOF:
472+
caps = append(caps, "CLIENT_DEPRECATE_EOF")
473+
case CLIENT_OPTIONAL_RESULTSET_METADATA:
474+
caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA")
475+
case CLIENT_ZSTD_COMPRESSION_ALGORITHM:
476+
caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM")
477+
case CLIENT_QUERY_ATTRIBUTES:
478+
caps = append(caps, "CLIENT_QUERY_ATTRIBUTES")
479+
case MULTI_FACTOR_AUTHENTICATION:
480+
caps = append(caps, "MULTI_FACTOR_AUTHENTICATION")
481+
case CLIENT_CAPABILITY_EXTENSION:
482+
caps = append(caps, "CLIENT_CAPABILITY_EXTENSION")
483+
case CLIENT_SSL_VERIFY_SERVER_CERT:
484+
caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT")
485+
case CLIENT_REMEMBER_OPTIONS:
486+
caps = append(caps, "CLIENT_REMEMBER_OPTIONS")
487+
default:
488+
caps = append(caps, fmt.Sprintf("(%d)", field))
489+
}
490+
}
491+
492+
return strings.Join(caps, "|")
493+
}
494+
495+
func (c *Conn) StatusString() string {
496+
var stats []string
497+
status := c.status
498+
for i := 0; status != 0; i++ {
499+
field := uint16(1 << i)
500+
if status&field == 0 {
501+
continue
502+
}
503+
status ^= field
504+
505+
switch field {
506+
case SERVER_STATUS_IN_TRANS:
507+
stats = append(stats, "SERVER_STATUS_IN_TRANS")
508+
case SERVER_STATUS_AUTOCOMMIT:
509+
stats = append(stats, "SERVER_STATUS_AUTOCOMMIT")
510+
case SERVER_MORE_RESULTS_EXISTS:
511+
stats = append(stats, "SERVER_MORE_RESULTS_EXISTS")
512+
case SERVER_STATUS_NO_GOOD_INDEX_USED:
513+
stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED")
514+
case SERVER_STATUS_NO_INDEX_USED:
515+
stats = append(stats, "SERVER_STATUS_NO_INDEX_USED")
516+
case SERVER_STATUS_CURSOR_EXISTS:
517+
stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS")
518+
case SERVER_STATUS_LAST_ROW_SEND:
519+
stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND")
520+
case SERVER_STATUS_DB_DROPPED:
521+
stats = append(stats, "SERVER_STATUS_DB_DROPPED")
522+
case SERVER_STATUS_NO_BACKSLASH_ESCAPED:
523+
stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED")
524+
case SERVER_STATUS_METADATA_CHANGED:
525+
stats = append(stats, "SERVER_STATUS_METADATA_CHANGED")
526+
case SERVER_QUERY_WAS_SLOW:
527+
stats = append(stats, "SERVER_QUERY_WAS_SLOW")
528+
case SERVER_PS_OUT_PARAMS:
529+
stats = append(stats, "SERVER_PS_OUT_PARAMS")
530+
default:
531+
stats = append(stats, fmt.Sprintf("(%d)", field))
532+
}
533+
}
534+
535+
return strings.Join(stats, "|")
536+
}

‎mysql/const.go

+13-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ const (
7676
)
7777

7878
const (
79+
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__capabilities__flags.html
80+
7981
CLIENT_LONG_PASSWORD uint32 = 1 << iota
8082
CLIENT_FOUND_ROWS
8183
CLIENT_LONG_FLAG
@@ -98,6 +100,16 @@ const (
98100
CLIENT_PLUGIN_AUTH
99101
CLIENT_CONNECT_ATTRS
100102
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
103+
CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS
104+
CLIENT_SESSION_TRACK
105+
CLIENT_DEPRECATE_EOF
106+
CLIENT_OPTIONAL_RESULTSET_METADATA
107+
CLIENT_ZSTD_COMPRESSION_ALGORITHM
108+
CLIENT_QUERY_ATTRIBUTES
109+
MULTI_FACTOR_AUTHENTICATION
110+
CLIENT_CAPABILITY_EXTENSION
111+
CLIENT_SSL_VERIFY_SERVER_CERT
112+
CLIENT_REMEMBER_OPTIONS
101113
)
102114

103115
const (
@@ -119,7 +131,7 @@ const (
119131
MYSQL_TYPE_VARCHAR
120132
MYSQL_TYPE_BIT
121133

122-
//mysql 5.6
134+
// mysql 5.6
123135
MYSQL_TYPE_TIMESTAMP2
124136
MYSQL_TYPE_DATETIME2
125137
MYSQL_TYPE_TIME2

0 commit comments

Comments
 (0)
Please sign in to comment.