Skip to content

Change how we import the mysql package #982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 37 additions & 37 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@ import (
"encoding/binary"
"fmt"

. "github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/mysql"
"github.com/go-mysql-org/go-mysql/packet"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/charset"
)

const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
const defaultAuthPluginName = mysql.AUTH_NATIVE_PASSWORD

// defines the supported auth plugins
var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD}
var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD}

// helper function to determine what auth methods are allowed by this client
func authPluginAllowed(pluginName string) bool {
Expand All @@ -38,12 +38,12 @@ func (c *Conn) readInitialHandshake() error {
return errors.Trace(err)
}

if data[0] == ERR_HEADER {
if data[0] == mysql.ERR_HEADER {
return errors.Annotate(c.handleErrorPacket(data), "read initial handshake error")
}

if data[0] != ClassicProtocolVersion {
if data[0] == XProtocolVersion {
if data[0] != mysql.ClassicProtocolVersion {
if data[0] == mysql.XProtocolVersion {
return errors.Errorf(
"invalid protocol version %d, expected 10. "+
"This might be X Protocol, make sure to connect to the right port",
Expand Down Expand Up @@ -75,10 +75,10 @@ func (c *Conn) readInitialHandshake() error {
// The lower 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2]))
// check protocol
if c.capability&CLIENT_PROTOCOL_41 == 0 {
if c.capability&mysql.CLIENT_PROTOCOL_41 == 0 {
return errors.New("the MySQL server can not support protocol 41 and above required by the client")
}
if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil {
if c.capability&mysql.CLIENT_SSL == 0 && c.tlsConfig != nil {
return errors.New("the MySQL Server does not support TLS required by the client")
}
pos += 2
Expand All @@ -97,15 +97,15 @@ func (c *Conn) readInitialHandshake() error {

// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
authPluginDataLen := data[pos]
if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
if (c.capability&mysql.CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) {
return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen)
}
pos++

// skip reserved (all [00] ?)
pos += 10

if c.capability&CLIENT_SECURE_CONNECTION != 0 {
if c.capability&mysql.CLIENT_SECURE_CONNECTION != 0 {
// Rest of the plugin provided data (scramble)

// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html
Expand All @@ -125,7 +125,7 @@ func (c *Conn) readInitialHandshake() error {
c.salt = append(c.salt, authPluginDataPart2...)
}

if c.capability&CLIENT_PLUGIN_AUTH != 0 {
if c.capability&mysql.CLIENT_PLUGIN_AUTH != 0 {
c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)])
pos += len(c.authPluginName)

Expand Down Expand Up @@ -153,13 +153,13 @@ func (c *Conn) readInitialHandshake() error {
func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) {
// password hashing
switch c.authPluginName {
case AUTH_NATIVE_PASSWORD:
return CalcPassword(authData[:20], []byte(c.password)), false, nil
case AUTH_CACHING_SHA2_PASSWORD:
return CalcCachingSha2Password(authData, c.password), false, nil
case AUTH_CLEAR_PASSWORD:
case mysql.AUTH_NATIVE_PASSWORD:
return mysql.CalcPassword(authData[:20], []byte(c.password)), false, nil
case mysql.AUTH_CACHING_SHA2_PASSWORD:
return mysql.CalcCachingSha2Password(authData, c.password), false, nil
case mysql.AUTH_CLEAR_PASSWORD:
return []byte(c.password), true, nil
case AUTH_SHA256_PASSWORD:
case mysql.AUTH_SHA256_PASSWORD:
if len(c.password) == 0 {
return nil, true, nil
}
Expand All @@ -186,10 +186,10 @@ func (c *Conn) genAttributes() []byte {

attrData := make([]byte, 0)
for k, v := range c.attributes {
attrData = append(attrData, PutLengthEncodedString([]byte(k))...)
attrData = append(attrData, PutLengthEncodedString([]byte(v))...)
attrData = append(attrData, mysql.PutLengthEncodedString([]byte(k))...)
attrData = append(attrData, mysql.PutLengthEncodedString([]byte(v))...)
}
return append(PutLengthEncodedInt(uint64(len(attrData))), attrData...)
return append(mysql.PutLengthEncodedInt(uint64(len(attrData))), attrData...)
}

// See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
Expand All @@ -199,23 +199,23 @@ func (c *Conn) writeAuthHandshake() error {
}

// Set default client capabilities that reflect the abilities of this library
capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION |
CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH
capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION |
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH
// Adjust client capability flags based on server support
capability |= c.capability & CLIENT_LONG_FLAG
capability |= c.capability & CLIENT_QUERY_ATTRIBUTES
capability |= c.capability & mysql.CLIENT_LONG_FLAG
capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES
// Adjust client capability flags on specific client requests
// Only flags that would make any sense setting and aren't handled elsewhere
// in the library are supported here
capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE |
c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS |
c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS |
c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM |
c.ccaps&CLIENT_LOCAL_FILES
capability |= c.ccaps&mysql.CLIENT_FOUND_ROWS | c.ccaps&mysql.CLIENT_IGNORE_SPACE |
c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS |
c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS |
c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM |
c.ccaps&mysql.CLIENT_LOCAL_FILES

// To enable TLS / SSL
if c.tlsConfig != nil {
capability |= CLIENT_SSL
capability |= mysql.CLIENT_SSL
}

auth, addNull, err := c.genAuthResponse(c.salt)
Expand All @@ -227,11 +227,11 @@ func (c *Conn) writeAuthHandshake() error {
// here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte
// see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer
var authRespLEIBuf [9]byte
authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
authRespLEI := mysql.AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth)))
if len(authRespLEI) > 1 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
capability |= mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
}

// packet length
Expand All @@ -248,16 +248,16 @@ func (c *Conn) writeAuthHandshake() error {
}
// db name
if len(c.db) > 0 {
capability |= CLIENT_CONNECT_WITH_DB
capability |= mysql.CLIENT_CONNECT_WITH_DB
length += len(c.db) + 1
}
// connection attributes
attrData := c.genAttributes()
if len(attrData) > 0 {
capability |= CLIENT_CONNECT_ATTRS
capability |= mysql.CLIENT_CONNECT_ATTRS
length += len(attrData)
}
if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
if c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
length++
}

Expand All @@ -279,7 +279,7 @@ func (c *Conn) writeAuthHandshake() error {
// use default collation id 255 here, is `utf8mb4_0900_ai_ci`
collationName := c.collation
if len(collationName) == 0 {
collationName = DEFAULT_COLLATION_NAME
collationName = mysql.DEFAULT_COLLATION_NAME
}
collation, err := charset.GetCollationByName(collationName)
if err != nil {
Expand Down Expand Up @@ -347,7 +347,7 @@ func (c *Conn) writeAuthHandshake() error {
pos += copy(data[pos:], attrData)
}

if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
if c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 {
// zstd_compression_level
data[pos] = 0x03
}
Expand Down
Loading
Loading