From d38bd7ed027c23ea213939c24da99b2e46c80342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Eeden?= Date: Fri, 7 Feb 2025 10:43:44 +0100 Subject: [PATCH] Change how we import the mysql package --- client/auth.go | 74 +++--- client/conn.go | 176 +++++++-------- client/resp.go | 108 ++++----- client/stmt.go | 106 ++++----- dump/dumper.go | 4 +- packet/conn.go | 60 ++--- replication/backup.go | 8 +- replication/binlogsyncer.go | 60 ++--- replication/event.go | 31 ++- replication/json_binary.go | 27 ++- replication/row_event.go | 272 +++++++++++------------ replication/transaction_payload_event.go | 13 +- server/auth.go | 14 +- server/auth_switch_response.go | 8 +- server/command.go | 53 +++-- server/conn.go | 21 +- server/handshake_resp.go | 37 +-- server/resp.go | 68 +++--- server/server_conf.go | 22 +- server/stmt.go | 88 ++++---- 20 files changed, 625 insertions(+), 625 deletions(-) diff --git a/client/auth.go b/client/auth.go index 5e1e4c45f..972a3acfc 100644 --- a/client/auth.go +++ b/client/auth.go @@ -6,16 +6,16 @@ import ( "encoding/binary" "fmt" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/charset" ) -const defaultAuthPluginName = AUTH_NATIVE_PASSWORD +const defaultAuthPluginName = mysql.AUTH_NATIVE_PASSWORD // defines the supported auth plugins -var supportedAuthPlugins = []string{AUTH_NATIVE_PASSWORD, AUTH_SHA256_PASSWORD, AUTH_CACHING_SHA2_PASSWORD} +var supportedAuthPlugins = []string{mysql.AUTH_NATIVE_PASSWORD, mysql.AUTH_SHA256_PASSWORD, mysql.AUTH_CACHING_SHA2_PASSWORD} // helper function to determine what auth methods are allowed by this client func authPluginAllowed(pluginName string) bool { @@ -38,12 +38,12 @@ func (c *Conn) readInitialHandshake() error { return errors.Trace(err) } - if data[0] == ERR_HEADER { + if data[0] == mysql.ERR_HEADER { return errors.Annotate(c.handleErrorPacket(data), "read initial handshake error") } - if data[0] != ClassicProtocolVersion { - if data[0] == XProtocolVersion { + if data[0] != mysql.ClassicProtocolVersion { + if data[0] == mysql.XProtocolVersion { return errors.Errorf( "invalid protocol version %d, expected 10. "+ "This might be X Protocol, make sure to connect to the right port", @@ -75,10 +75,10 @@ func (c *Conn) readInitialHandshake() error { // The lower 2 bytes of the Capabilities Flags c.capability = uint32(binary.LittleEndian.Uint16(data[pos : pos+2])) // check protocol - if c.capability&CLIENT_PROTOCOL_41 == 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 == 0 { return errors.New("the MySQL server can not support protocol 41 and above required by the client") } - if c.capability&CLIENT_SSL == 0 && c.tlsConfig != nil { + if c.capability&mysql.CLIENT_SSL == 0 && c.tlsConfig != nil { return errors.New("the MySQL Server does not support TLS required by the client") } pos += 2 @@ -97,7 +97,7 @@ func (c *Conn) readInitialHandshake() error { // length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0 authPluginDataLen := data[pos] - if (c.capability&CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) { + if (c.capability&mysql.CLIENT_PLUGIN_AUTH == 0) && (authPluginDataLen > 0) { return errors.Errorf("invalid auth plugin data filler %d", authPluginDataLen) } pos++ @@ -105,7 +105,7 @@ func (c *Conn) readInitialHandshake() error { // skip reserved (all [00] ?) pos += 10 - if c.capability&CLIENT_SECURE_CONNECTION != 0 { + if c.capability&mysql.CLIENT_SECURE_CONNECTION != 0 { // Rest of the plugin provided data (scramble) // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_handshake_v10.html @@ -125,7 +125,7 @@ func (c *Conn) readInitialHandshake() error { c.salt = append(c.salt, authPluginDataPart2...) } - if c.capability&CLIENT_PLUGIN_AUTH != 0 { + if c.capability&mysql.CLIENT_PLUGIN_AUTH != 0 { c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) pos += len(c.authPluginName) @@ -153,13 +153,13 @@ func (c *Conn) readInitialHandshake() error { func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) { // password hashing switch c.authPluginName { - case AUTH_NATIVE_PASSWORD: - return CalcPassword(authData[:20], []byte(c.password)), false, nil - case AUTH_CACHING_SHA2_PASSWORD: - return CalcCachingSha2Password(authData, c.password), false, nil - case AUTH_CLEAR_PASSWORD: + case mysql.AUTH_NATIVE_PASSWORD: + return mysql.CalcPassword(authData[:20], []byte(c.password)), false, nil + case mysql.AUTH_CACHING_SHA2_PASSWORD: + return mysql.CalcCachingSha2Password(authData, c.password), false, nil + case mysql.AUTH_CLEAR_PASSWORD: return []byte(c.password), true, nil - case AUTH_SHA256_PASSWORD: + case mysql.AUTH_SHA256_PASSWORD: if len(c.password) == 0 { return nil, true, nil } @@ -186,10 +186,10 @@ func (c *Conn) genAttributes() []byte { attrData := make([]byte, 0) for k, v := range c.attributes { - attrData = append(attrData, PutLengthEncodedString([]byte(k))...) - attrData = append(attrData, PutLengthEncodedString([]byte(v))...) + attrData = append(attrData, mysql.PutLengthEncodedString([]byte(k))...) + attrData = append(attrData, mysql.PutLengthEncodedString([]byte(v))...) } - return append(PutLengthEncodedInt(uint64(len(attrData))), attrData...) + return append(mysql.PutLengthEncodedInt(uint64(len(attrData))), attrData...) } // See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse @@ -199,23 +199,23 @@ func (c *Conn) writeAuthHandshake() error { } // Set default client capabilities that reflect the abilities of this library - capability := CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | - CLIENT_LONG_PASSWORD | CLIENT_TRANSACTIONS | CLIENT_PLUGIN_AUTH + capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION | + mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH // Adjust client capability flags based on server support - capability |= c.capability & CLIENT_LONG_FLAG - capability |= c.capability & CLIENT_QUERY_ATTRIBUTES + capability |= c.capability & mysql.CLIENT_LONG_FLAG + capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES // Adjust client capability flags on specific client requests // Only flags that would make any sense setting and aren't handled elsewhere // in the library are supported here - capability |= c.ccaps&CLIENT_FOUND_ROWS | c.ccaps&CLIENT_IGNORE_SPACE | - c.ccaps&CLIENT_MULTI_STATEMENTS | c.ccaps&CLIENT_MULTI_RESULTS | - c.ccaps&CLIENT_PS_MULTI_RESULTS | c.ccaps&CLIENT_CONNECT_ATTRS | - c.ccaps&CLIENT_COMPRESS | c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM | - c.ccaps&CLIENT_LOCAL_FILES + capability |= c.ccaps&mysql.CLIENT_FOUND_ROWS | c.ccaps&mysql.CLIENT_IGNORE_SPACE | + c.ccaps&mysql.CLIENT_MULTI_STATEMENTS | c.ccaps&mysql.CLIENT_MULTI_RESULTS | + c.ccaps&mysql.CLIENT_PS_MULTI_RESULTS | c.ccaps&mysql.CLIENT_CONNECT_ATTRS | + c.ccaps&mysql.CLIENT_COMPRESS | c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM | + c.ccaps&mysql.CLIENT_LOCAL_FILES // To enable TLS / SSL if c.tlsConfig != nil { - capability |= CLIENT_SSL + capability |= mysql.CLIENT_SSL } auth, addNull, err := c.genAuthResponse(c.salt) @@ -227,11 +227,11 @@ func (c *Conn) writeAuthHandshake() error { // here we use the Length-Encoded-Integer(LEI) as the data length may not fit into one byte // see: https://dev.mysql.com/doc/internals/en/integer.html#length-encoded-integer var authRespLEIBuf [9]byte - authRespLEI := AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth))) + authRespLEI := mysql.AppendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(auth))) if len(authRespLEI) > 1 { // if the length can not be written in 1 byte, it must be written as a // length encoded integer - capability |= CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + capability |= mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA } // packet length @@ -248,16 +248,16 @@ func (c *Conn) writeAuthHandshake() error { } // db name if len(c.db) > 0 { - capability |= CLIENT_CONNECT_WITH_DB + capability |= mysql.CLIENT_CONNECT_WITH_DB length += len(c.db) + 1 } // connection attributes attrData := c.genAttributes() if len(attrData) > 0 { - capability |= CLIENT_CONNECT_ATTRS + capability |= mysql.CLIENT_CONNECT_ATTRS length += len(attrData) } - if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + if c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { length++ } @@ -279,7 +279,7 @@ func (c *Conn) writeAuthHandshake() error { // use default collation id 255 here, is `utf8mb4_0900_ai_ci` collationName := c.collation if len(collationName) == 0 { - collationName = DEFAULT_COLLATION_NAME + collationName = mysql.DEFAULT_COLLATION_NAME } collation, err := charset.GetCollationByName(collationName) if err != nil { @@ -347,7 +347,7 @@ func (c *Conn) writeAuthHandshake() error { pos += copy(data[pos:], attrData) } - if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + if c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { // zstd_compression_level data[pos] = 0x03 } diff --git a/client/conn.go b/client/conn.go index 523113b4a..928e1dfd1 100644 --- a/client/conn.go +++ b/client/conn.go @@ -14,7 +14,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/parser/charset" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" "github.com/go-mysql-org/go-mysql/utils" ) @@ -58,20 +58,20 @@ type Conn struct { connectionID uint32 - queryAttributes []QueryAttribute + queryAttributes []mysql.QueryAttribute // Include the file + line as query attribute. The number set which frame in the stack should be used. includeLine int } // This function will be called for every row in resultset from ExecuteSelectStreaming. -type SelectPerRowCallback func(row []FieldValue) error +type SelectPerRowCallback func(row []mysql.FieldValue) error // This function will be called once per result from ExecuteSelectStreaming -type SelectPerResultCallback func(result *Result) error +type SelectPerResultCallback func(result *mysql.Result) error // This function will be called once per result from ExecuteMultiple -type ExecPerResultCallback func(result *Result, err error) +type ExecPerResultCallback func(result *mysql.Result, err error) func getNetProto(addr string) string { proto := "tcp" @@ -138,7 +138,7 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam c.proto = network // use default charset here, utf-8 - c.charset = DEFAULT_CHARSET + c.charset = mysql.DEFAULT_CHARSET // Apply configuration functions. for _, option := range options { @@ -161,10 +161,10 @@ func ConnectWithDialer(ctx context.Context, network, addr, user, password, dbNam return nil, errors.Trace(err) } - if c.ccaps&CLIENT_COMPRESS > 0 { - c.Conn.Compression = MYSQL_COMPRESS_ZLIB - } else if c.ccaps&CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { - c.Conn.Compression = MYSQL_COMPRESS_ZSTD + if c.ccaps&mysql.CLIENT_COMPRESS > 0 { + c.Conn.Compression = mysql.MYSQL_COMPRESS_ZLIB + } else if c.ccaps&mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM > 0 { + c.Conn.Compression = mysql.MYSQL_COMPRESS_ZSTD } // if a collation was set with a ID of > 255, then we need to call SET NAMES ... @@ -215,14 +215,14 @@ func (c *Conn) Close() error { // Quit sends COM_QUIT to the server and then closes the connection. Use Close() to directly close the connection. func (c *Conn) Quit() error { - if err := c.writeCommand(COM_QUIT); err != nil { + if err := c.writeCommand(mysql.COM_QUIT); err != nil { return err } return c.Close() } func (c *Conn) Ping() error { - if err := c.writeCommand(COM_PING); err != nil { + if err := c.writeCommand(mysql.COM_PING); err != nil { return errors.Trace(err) } @@ -265,7 +265,7 @@ func (c *Conn) UseDB(dbName string) error { return nil } - if err := c.writeCommandStr(COM_INIT_DB, dbName); err != nil { + if err := c.writeCommandStr(mysql.COM_INIT_DB, dbName); err != nil { return errors.Trace(err) } @@ -291,17 +291,17 @@ func (c *Conn) GetServerVersion() string { // of the server and returns 0 if they are equal, and 1 if the server version // is higher and -1 if the server version is lower. func (c *Conn) CompareServerVersion(v string) (int, error) { - return CompareServerVersions(c.serverVersion, v) + return mysql.CompareServerVersions(c.serverVersion, v) } -func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { +func (c *Conn) Execute(command string, args ...interface{}) (*mysql.Result, error) { if len(args) == 0 { return c.exec(command) } else { if s, err := c.Prepare(command); err != nil { return nil, errors.Trace(err) } else { - var r *Result + var r *mysql.Result r, err = s.Execute(args...) s.Close() return r, err @@ -315,13 +315,13 @@ func (c *Conn) Execute(command string, args ...interface{}) (*Result, error) { // When ExecuteMultiple is used, the connection should have the SERVER_MORE_RESULTS_EXISTS // flag set to signal the server multiple queries are executed. Handling the responses // is up to the implementation of perResultCallback. -func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*Result, error) { +func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCallback) (*mysql.Result, error) { if err := c.execSend(query); err != nil { return nil, errors.Trace(err) } var err error - var result *Result + var result *mysql.Result bs := utils.ByteSliceGet(16) defer utils.ByteSlicePut(bs) @@ -333,13 +333,13 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall } switch bs.B[0] { - case OK_HEADER: + case mysql.OK_HEADER: result, err = c.handleOKPacket(bs.B) - case ERR_HEADER: + case mysql.ERR_HEADER: err = c.handleErrorPacket(bytes.Repeat(bs.B, 1)) result = nil - case LocalInFile_HEADER: - err = ErrMalformPacket + case mysql.LocalInFile_HEADER: + err = mysql.ErrMalformPacket result = nil default: result, err = c.readResultset(bs.B, false) @@ -348,7 +348,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall perResultCallback(result, err) // if there was an error of this was the last result, stop looping - if err != nil || result.Status&SERVER_MORE_RESULTS_EXISTS == 0 { + if err != nil || result.Status&mysql.SERVER_MORE_RESULTS_EXISTS == 0 { break } } @@ -357,10 +357,10 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall // streaming session // if this would end up in WriteValue, it would just be ignored as all // responses should have been handled in perResultCallback - rs := NewResultset(0) - rs.Streaming = StreamingMultiple + rs := mysql.NewResultset(0) + rs.Streaming = mysql.StreamingMultiple rs.StreamingDone = true - return NewResult(rs), nil + return mysql.NewResult(rs), nil } // ExecuteSelectStreaming will call perRowCallback for every row in resultset @@ -368,7 +368,7 @@ func (c *Conn) ExecuteMultiple(query string, perResultCallback ExecPerResultCall // When given, perResultCallback will be called once per result // // ExecuteSelectStreaming should be used only for SELECT queries with a large response resultset for memory preserving. -func (c *Conn) ExecuteSelectStreaming(command string, result *Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { +func (c *Conn) ExecuteSelectStreaming(command string, result *mysql.Result, perRowCallback SelectPerRowCallback, perResultCallback SelectPerResultCallback) error { if err := c.execSend(command); err != nil { return errors.Trace(err) } @@ -425,13 +425,13 @@ func (c *Conn) GetCollation() string { } // FieldList uses COM_FIELD_LIST to get a list of fields from a table -func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { - if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil { +func (c *Conn) FieldList(table string, wildcard string) ([]*mysql.Field, error) { + if err := c.writeCommandStrStr(mysql.COM_FIELD_LIST, table, wildcard); err != nil { return nil, errors.Trace(err) } - fs := make([]*Field, 0, 4) - var f *Field + fs := make([]*mysql.Field, 0, 4) + var f *mysql.Field for { data, err := c.ReadPacket() if err != nil { @@ -439,7 +439,7 @@ func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { } // ERR Packet - if data[0] == ERR_HEADER { + if data[0] == mysql.ERR_HEADER { return nil, c.handleErrorPacket(data) } @@ -448,7 +448,7 @@ func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) { return fs, nil } - if f, err = FieldData(data).Parse(); err != nil { + if f, err = mysql.FieldData(data).Parse(); err != nil { return nil, errors.Trace(err) } fs = append(fs, f) @@ -466,12 +466,12 @@ func (c *Conn) SetAutoCommit() error { // IsAutoCommit returns true if SERVER_STATUS_AUTOCOMMIT is set func (c *Conn) IsAutoCommit() bool { - return c.status&SERVER_STATUS_AUTOCOMMIT > 0 + return c.status&mysql.SERVER_STATUS_AUTOCOMMIT > 0 } // IsInTransaction returns true if SERVER_STATUS_IN_TRANS is set func (c *Conn) IsInTransaction() bool { - return c.status&SERVER_STATUS_IN_TRANS > 0 + return c.status&mysql.SERVER_STATUS_IN_TRANS > 0 } func (c *Conn) GetCharset() string { @@ -482,7 +482,7 @@ func (c *Conn) GetConnectionID() uint32 { return c.connectionID } -func (c *Conn) HandleOKPacket(data []byte) *Result { +func (c *Conn) HandleOKPacket(data []byte) *mysql.Result { r, _ := c.handleOKPacket(data) return r } @@ -491,12 +491,12 @@ func (c *Conn) HandleErrorPacket(data []byte) error { return c.handleErrorPacket(data) } -func (c *Conn) ReadOKPacket() (*Result, error) { +func (c *Conn) ReadOKPacket() (*mysql.Result, error) { return c.readOK() } // Send COM_QUERY and read the result -func (c *Conn) exec(query string) (*Result, error) { +func (c *Conn) exec(query string) (*mysql.Result, error) { err := c.execSend(query) if err != nil { return nil, errors.Trace(err) @@ -510,11 +510,11 @@ func (c *Conn) execSend(query string) error { var buf bytes.Buffer defer clear(c.queryAttributes) - if c.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + if c.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0 { if c.includeLine >= 0 { _, file, line, ok := runtime.Caller(c.includeLine) if ok { - lineAttr := QueryAttribute{ + lineAttr := mysql.QueryAttribute{ Name: "_line", Value: fmt.Sprintf("%s:%d", file, line), } @@ -523,7 +523,7 @@ func (c *Conn) execSend(query string) error { } numParams := len(c.queryAttributes) - buf.Write(PutLengthEncodedInt(uint64(numParams))) + buf.Write(mysql.PutLengthEncodedInt(uint64(numParams))) buf.WriteByte(0x1) // parameter_set_count, unused if numParams > 0 { // null_bitmap, length: (num_params+7)/8 @@ -533,7 +533,7 @@ func (c *Conn) execSend(query string) error { buf.WriteByte(0x1) // new_params_bind_flag, unused for _, qa := range c.queryAttributes { buf.Write(qa.TypeAndFlag()) - buf.Write(PutLengthEncodedString([]byte(qa.Name))) + buf.Write(mysql.PutLengthEncodedString([]byte(qa.Name))) } for _, qa := range c.queryAttributes { buf.Write(qa.ValueBytes()) @@ -546,7 +546,7 @@ func (c *Conn) execSend(query string) error { return err } - if err := c.writeCommandBuf(COM_QUERY, buf.Bytes()); err != nil { + if err := c.writeCommandBuf(mysql.COM_QUERY, buf.Bytes()); err != nil { return errors.Trace(err) } @@ -567,69 +567,69 @@ func (c *Conn) CapabilityString() string { capability ^= field switch field { - case CLIENT_LONG_PASSWORD: + case mysql.CLIENT_LONG_PASSWORD: caps = append(caps, "CLIENT_LONG_PASSWORD") - case CLIENT_FOUND_ROWS: + case mysql.CLIENT_FOUND_ROWS: caps = append(caps, "CLIENT_FOUND_ROWS") - case CLIENT_LONG_FLAG: + case mysql.CLIENT_LONG_FLAG: caps = append(caps, "CLIENT_LONG_FLAG") - case CLIENT_CONNECT_WITH_DB: + case mysql.CLIENT_CONNECT_WITH_DB: caps = append(caps, "CLIENT_CONNECT_WITH_DB") - case CLIENT_NO_SCHEMA: + case mysql.CLIENT_NO_SCHEMA: caps = append(caps, "CLIENT_NO_SCHEMA") - case CLIENT_COMPRESS: + case mysql.CLIENT_COMPRESS: caps = append(caps, "CLIENT_COMPRESS") - case CLIENT_ODBC: + case mysql.CLIENT_ODBC: caps = append(caps, "CLIENT_ODBC") - case CLIENT_LOCAL_FILES: + case mysql.CLIENT_LOCAL_FILES: caps = append(caps, "CLIENT_LOCAL_FILES") - case CLIENT_IGNORE_SPACE: + case mysql.CLIENT_IGNORE_SPACE: caps = append(caps, "CLIENT_IGNORE_SPACE") - case CLIENT_PROTOCOL_41: + case mysql.CLIENT_PROTOCOL_41: caps = append(caps, "CLIENT_PROTOCOL_41") - case CLIENT_INTERACTIVE: + case mysql.CLIENT_INTERACTIVE: caps = append(caps, "CLIENT_INTERACTIVE") - case CLIENT_SSL: + case mysql.CLIENT_SSL: caps = append(caps, "CLIENT_SSL") - case CLIENT_IGNORE_SIGPIPE: + case mysql.CLIENT_IGNORE_SIGPIPE: caps = append(caps, "CLIENT_IGNORE_SIGPIPE") - case CLIENT_TRANSACTIONS: + case mysql.CLIENT_TRANSACTIONS: caps = append(caps, "CLIENT_TRANSACTIONS") - case CLIENT_RESERVED: + case mysql.CLIENT_RESERVED: caps = append(caps, "CLIENT_RESERVED") - case CLIENT_SECURE_CONNECTION: + case mysql.CLIENT_SECURE_CONNECTION: caps = append(caps, "CLIENT_SECURE_CONNECTION") - case CLIENT_MULTI_STATEMENTS: + case mysql.CLIENT_MULTI_STATEMENTS: caps = append(caps, "CLIENT_MULTI_STATEMENTS") - case CLIENT_MULTI_RESULTS: + case mysql.CLIENT_MULTI_RESULTS: caps = append(caps, "CLIENT_MULTI_RESULTS") - case CLIENT_PS_MULTI_RESULTS: + case mysql.CLIENT_PS_MULTI_RESULTS: caps = append(caps, "CLIENT_PS_MULTI_RESULTS") - case CLIENT_PLUGIN_AUTH: + case mysql.CLIENT_PLUGIN_AUTH: caps = append(caps, "CLIENT_PLUGIN_AUTH") - case CLIENT_CONNECT_ATTRS: + case mysql.CLIENT_CONNECT_ATTRS: caps = append(caps, "CLIENT_CONNECT_ATTRS") - case CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: + case mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA: caps = append(caps, "CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA") - case CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: + case mysql.CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS: caps = append(caps, "CLIENT_CAN_HANDLE_EXPIRED_PASSWORDS") - case CLIENT_SESSION_TRACK: + case mysql.CLIENT_SESSION_TRACK: caps = append(caps, "CLIENT_SESSION_TRACK") - case CLIENT_DEPRECATE_EOF: + case mysql.CLIENT_DEPRECATE_EOF: caps = append(caps, "CLIENT_DEPRECATE_EOF") - case CLIENT_OPTIONAL_RESULTSET_METADATA: + case mysql.CLIENT_OPTIONAL_RESULTSET_METADATA: caps = append(caps, "CLIENT_OPTIONAL_RESULTSET_METADATA") - case CLIENT_ZSTD_COMPRESSION_ALGORITHM: + case mysql.CLIENT_ZSTD_COMPRESSION_ALGORITHM: caps = append(caps, "CLIENT_ZSTD_COMPRESSION_ALGORITHM") - case CLIENT_QUERY_ATTRIBUTES: + case mysql.CLIENT_QUERY_ATTRIBUTES: caps = append(caps, "CLIENT_QUERY_ATTRIBUTES") - case MULTI_FACTOR_AUTHENTICATION: + case mysql.MULTI_FACTOR_AUTHENTICATION: caps = append(caps, "MULTI_FACTOR_AUTHENTICATION") - case CLIENT_CAPABILITY_EXTENSION: + case mysql.CLIENT_CAPABILITY_EXTENSION: caps = append(caps, "CLIENT_CAPABILITY_EXTENSION") - case CLIENT_SSL_VERIFY_SERVER_CERT: + case mysql.CLIENT_SSL_VERIFY_SERVER_CERT: caps = append(caps, "CLIENT_SSL_VERIFY_SERVER_CERT") - case CLIENT_REMEMBER_OPTIONS: + case mysql.CLIENT_REMEMBER_OPTIONS: caps = append(caps, "CLIENT_REMEMBER_OPTIONS") default: caps = append(caps, fmt.Sprintf("(%d)", field)) @@ -652,29 +652,29 @@ func (c *Conn) StatusString() string { status ^= field switch field { - case SERVER_STATUS_IN_TRANS: + case mysql.SERVER_STATUS_IN_TRANS: stats = append(stats, "SERVER_STATUS_IN_TRANS") - case SERVER_STATUS_AUTOCOMMIT: + case mysql.SERVER_STATUS_AUTOCOMMIT: stats = append(stats, "SERVER_STATUS_AUTOCOMMIT") - case SERVER_MORE_RESULTS_EXISTS: + case mysql.SERVER_MORE_RESULTS_EXISTS: stats = append(stats, "SERVER_MORE_RESULTS_EXISTS") - case SERVER_STATUS_NO_GOOD_INDEX_USED: + case mysql.SERVER_STATUS_NO_GOOD_INDEX_USED: stats = append(stats, "SERVER_STATUS_NO_GOOD_INDEX_USED") - case SERVER_STATUS_NO_INDEX_USED: + case mysql.SERVER_STATUS_NO_INDEX_USED: stats = append(stats, "SERVER_STATUS_NO_INDEX_USED") - case SERVER_STATUS_CURSOR_EXISTS: + case mysql.SERVER_STATUS_CURSOR_EXISTS: stats = append(stats, "SERVER_STATUS_CURSOR_EXISTS") - case SERVER_STATUS_LAST_ROW_SEND: + case mysql.SERVER_STATUS_LAST_ROW_SEND: stats = append(stats, "SERVER_STATUS_LAST_ROW_SEND") - case SERVER_STATUS_DB_DROPPED: + case mysql.SERVER_STATUS_DB_DROPPED: stats = append(stats, "SERVER_STATUS_DB_DROPPED") - case SERVER_STATUS_NO_BACKSLASH_ESCAPED: + case mysql.SERVER_STATUS_NO_BACKSLASH_ESCAPED: stats = append(stats, "SERVER_STATUS_NO_BACKSLASH_ESCAPED") - case SERVER_STATUS_METADATA_CHANGED: + case mysql.SERVER_STATUS_METADATA_CHANGED: stats = append(stats, "SERVER_STATUS_METADATA_CHANGED") - case SERVER_QUERY_WAS_SLOW: + case mysql.SERVER_QUERY_WAS_SLOW: stats = append(stats, "SERVER_QUERY_WAS_SLOW") - case SERVER_PS_OUT_PARAMS: + case mysql.SERVER_PS_OUT_PARAMS: stats = append(stats, "SERVER_PS_OUT_PARAMS") default: stats = append(stats, fmt.Sprintf("(%d)", field)) @@ -685,7 +685,7 @@ func (c *Conn) StatusString() string { } // SetQueryAttributes sets the query attributes to be send along with the next query -func (c *Conn) SetQueryAttributes(attrs ...QueryAttribute) error { +func (c *Conn) SetQueryAttributes(attrs ...mysql.QueryAttribute) error { c.queryAttributes = attrs return nil } diff --git a/client/resp.go b/client/resp.go index a82f82408..6cd7e9bcf 100644 --- a/client/resp.go +++ b/client/resp.go @@ -10,7 +10,7 @@ import ( "github.com/pingcap/errors" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" ) @@ -32,21 +32,21 @@ func (c *Conn) readUntilEOF() (err error) { } func (c *Conn) isEOFPacket(data []byte) bool { - return data[0] == EOF_HEADER && len(data) <= 5 + return data[0] == mysql.EOF_HEADER && len(data) <= 5 } -func (c *Conn) handleOKPacket(data []byte) (*Result, error) { +func (c *Conn) handleOKPacket(data []byte) (*mysql.Result, error) { var n int var pos = 1 - r := NewResultReserveResultset(0) + r := mysql.NewResultReserveResultset(0) - r.AffectedRows, _, n = LengthEncodedInt(data[pos:]) + r.AffectedRows, _, n = mysql.LengthEncodedInt(data[pos:]) pos += n - r.InsertId, _, n = LengthEncodedInt(data[pos:]) + r.InsertId, _, n = mysql.LengthEncodedInt(data[pos:]) pos += n - if c.capability&CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { r.Status = binary.LittleEndian.Uint16(data[pos:]) c.status = r.Status pos += 2 @@ -54,7 +54,7 @@ func (c *Conn) handleOKPacket(data []byte) (*Result, error) { //todo:strict_mode, check warnings as error r.Warnings = binary.LittleEndian.Uint16(data[pos:]) // pos += 2 - } else if c.capability&CLIENT_TRANSACTIONS > 0 { + } else if c.capability&mysql.CLIENT_TRANSACTIONS > 0 { r.Status = binary.LittleEndian.Uint16(data[pos:]) c.status = r.Status // pos += 2 @@ -67,14 +67,14 @@ func (c *Conn) handleOKPacket(data []byte) (*Result, error) { } func (c *Conn) handleErrorPacket(data []byte) error { - e := new(MyError) + e := new(mysql.MyError) var pos = 1 e.Code = binary.LittleEndian.Uint16(data[pos:]) pos += 2 - if c.capability&CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { // skip '#' pos++ e.State = utils.ByteSliceToString(data[pos : pos+5]) @@ -122,14 +122,14 @@ func (c *Conn) handleAuthResult() error { } // handle caching_sha2_password - if c.authPluginName == AUTH_CACHING_SHA2_PASSWORD { + if c.authPluginName == mysql.AUTH_CACHING_SHA2_PASSWORD { if data == nil { return nil // auth already succeeded } - if data[0] == CACHE_SHA2_FAST_AUTH { + if data[0] == mysql.CACHE_SHA2_FAST_AUTH { _, err = c.readOK() return err - } else if data[0] == CACHE_SHA2_FULL_AUTH { + } else if data[0] == mysql.CACHE_SHA2_FULL_AUTH { // need full authentication if c.tlsConfig != nil || c.proto == "unix" { if err = c.WriteClearAuthPacket(c.password); err != nil { @@ -145,7 +145,7 @@ func (c *Conn) handleAuthResult() error { } else { return errors.Errorf("invalid packet %x", data[0]) } - } else if c.authPluginName == AUTH_SHA256_PASSWORD { + } else if c.authPluginName == mysql.AUTH_SHA256_PASSWORD { if len(data) == 0 { return nil // auth already succeeded } @@ -174,18 +174,18 @@ func (c *Conn) readAuthResult() ([]byte, string, error) { // see: https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ // packet indicator switch data[0] { - case OK_HEADER: + case mysql.OK_HEADER: _, err := c.handleOKPacket(data) return nil, "", err - case MORE_DATE_HEADER: + case mysql.MORE_DATE_HEADER: return data[1:], "", err - case EOF_HEADER: + case mysql.EOF_HEADER: // server wants to switch auth if len(data) < 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest - return nil, AUTH_MYSQL_OLD_PASSWORD, nil + return nil, mysql.AUTH_MYSQL_OLD_PASSWORD, nil } pluginEndIndex := bytes.IndexByte(data, 0x00) if pluginEndIndex < 0 { @@ -200,22 +200,22 @@ func (c *Conn) readAuthResult() ([]byte, string, error) { } } -func (c *Conn) readOK() (*Result, error) { +func (c *Conn) readOK() (*mysql.Result, error) { data, err := c.ReadPacket() if err != nil { return nil, errors.Trace(err) } - if data[0] == OK_HEADER { + if data[0] == mysql.OK_HEADER { return c.handleOKPacket(data) - } else if data[0] == ERR_HEADER { + } else if data[0] == mysql.ERR_HEADER { return nil, c.handleErrorPacket(data) } else { return nil, errors.New("invalid ok packet") } } -func (c *Conn) readResult(binary bool) (*Result, error) { +func (c *Conn) readResult(binary bool) (*mysql.Result, error) { bs := utils.ByteSliceGet(16) defer utils.ByteSlicePut(bs) var err error @@ -225,18 +225,18 @@ func (c *Conn) readResult(binary bool) (*Result, error) { } switch bs.B[0] { - case OK_HEADER: + case mysql.OK_HEADER: return c.handleOKPacket(bs.B) - case ERR_HEADER: + case mysql.ERR_HEADER: return nil, c.handleErrorPacket(bytes.Repeat(bs.B, 1)) - case LocalInFile_HEADER: - return nil, ErrMalformPacket + case mysql.LocalInFile_HEADER: + return nil, mysql.ErrMalformPacket default: return c.readResultset(bs.B, binary) } } -func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error { +func (c *Conn) readResultStreaming(binary bool, result *mysql.Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error { bs := utils.ByteSliceGet(16) defer utils.ByteSlicePut(bs) var err error @@ -246,7 +246,7 @@ func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectP } switch bs.B[0] { - case OK_HEADER: + case mysql.OK_HEADER: // https://dev.mysql.com/doc/internals/en/com-query-response.html // 14.6.4.1 COM_QUERY Response // If the number of columns in the resultset is 0, this is a OK_Packet. @@ -261,29 +261,29 @@ func (c *Conn) readResultStreaming(binary bool, result *Result, perRowCb SelectP result.InsertId = okResult.InsertId result.Warnings = okResult.Warnings if result.Resultset == nil { - result.Resultset = NewResultset(0) + result.Resultset = mysql.NewResultset(0) } else { result.Reset(0) } return nil - case ERR_HEADER: + case mysql.ERR_HEADER: return c.handleErrorPacket(bytes.Repeat(bs.B, 1)) - case LocalInFile_HEADER: - return ErrMalformPacket + case mysql.LocalInFile_HEADER: + return mysql.ErrMalformPacket default: return c.readResultsetStreaming(bs.B, binary, result, perRowCb, perResCb) } } -func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) { +func (c *Conn) readResultset(data []byte, binary bool) (*mysql.Result, error) { // column count - count, _, n := LengthEncodedInt(data) + count, _, n := mysql.LengthEncodedInt(data) if n-len(data) != 0 { - return nil, ErrMalformPacket + return nil, mysql.ErrMalformPacket } - result := NewResultReserveResultset(int(count)) + result := mysql.NewResultReserveResultset(int(count)) if err := c.readResultColumns(result); err != nil { return nil, errors.Trace(err) @@ -296,22 +296,22 @@ func (c *Conn) readResultset(data []byte, binary bool) (*Result, error) { return result, nil } -func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error { - columnCount, _, n := LengthEncodedInt(data) +func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *mysql.Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback) error { + columnCount, _, n := mysql.LengthEncodedInt(data) if n-len(data) != 0 { - return ErrMalformPacket + return mysql.ErrMalformPacket } if result.Resultset == nil { - result.Resultset = NewResultset(int(columnCount)) + result.Resultset = mysql.NewResultset(int(columnCount)) } else { // Reuse memory if can result.Reset(int(columnCount)) } // this is a streaming resultset - result.Resultset.Streaming = StreamingSelect + result.Resultset.Streaming = mysql.StreamingSelect if err := c.readResultColumns(result); err != nil { return errors.Trace(err) @@ -333,7 +333,7 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *Result, return nil } -func (c *Conn) readResultColumns(result *Result) (err error) { +func (c *Conn) readResultColumns(result *mysql.Result) (err error) { var i = 0 var data []byte @@ -347,7 +347,7 @@ func (c *Conn) readResultColumns(result *Result) (err error) { // EOF Packet if c.isEOFPacket(data) { - if c.capability&CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { result.Warnings = binary.LittleEndian.Uint16(data[1:]) // todo add strict_mode, warning will be treat as error result.Status = binary.LittleEndian.Uint16(data[3:]) @@ -355,14 +355,14 @@ func (c *Conn) readResultColumns(result *Result) (err error) { } if i != len(result.Fields) { - err = ErrMalformPacket + err = mysql.ErrMalformPacket } return err } if result.Fields[i] == nil { - result.Fields[i] = &Field{} + result.Fields[i] = &mysql.Field{} } err = result.Fields[i].Parse(data) if err != nil { @@ -375,7 +375,7 @@ func (c *Conn) readResultColumns(result *Result) (err error) { } } -func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) { +func (c *Conn) readResultRows(result *mysql.Result, isBinary bool) (err error) { var data []byte for { @@ -388,7 +388,7 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) { // EOF Packet if c.isEOFPacket(data) { - if c.capability&CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { result.Warnings = binary.LittleEndian.Uint16(data[1:]) // todo add strict_mode, warning will be treat as error result.Status = binary.LittleEndian.Uint16(data[3:]) @@ -398,7 +398,7 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) { break } - if data[0] == ERR_HEADER { + if data[0] == mysql.ERR_HEADER { return c.handleErrorPacket(data) } @@ -406,7 +406,7 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) { } if cap(result.Values) < len(result.RowDatas) { - result.Values = make([][]FieldValue, len(result.RowDatas)) + result.Values = make([][]mysql.FieldValue, len(result.RowDatas)) } else { result.Values = result.Values[:len(result.RowDatas)] } @@ -422,10 +422,10 @@ func (c *Conn) readResultRows(result *Result, isBinary bool) (err error) { return nil } -func (c *Conn) readResultRowsStreaming(result *Result, isBinary bool, perRowCb SelectPerRowCallback) (err error) { +func (c *Conn) readResultRowsStreaming(result *mysql.Result, isBinary bool, perRowCb SelectPerRowCallback) (err error) { var ( data []byte - row []FieldValue + row []mysql.FieldValue ) for { @@ -436,7 +436,7 @@ func (c *Conn) readResultRowsStreaming(result *Result, isBinary bool, perRowCb S // EOF Packet if c.isEOFPacket(data) { - if c.capability&CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { result.Warnings = binary.LittleEndian.Uint16(data[1:]) // todo add strict_mode, warning will be treat as error result.Status = binary.LittleEndian.Uint16(data[3:]) @@ -446,12 +446,12 @@ func (c *Conn) readResultRowsStreaming(result *Result, isBinary bool, perRowCb S break } - if data[0] == ERR_HEADER { + if data[0] == mysql.ERR_HEADER { return c.handleErrorPacket(data) } // Parse this row - row, err = RowData(data).Parse(result.Fields, isBinary, row) + row, err = mysql.RowData(data).Parse(result.Fields, isBinary, row) if err != nil { return errors.Trace(err) } diff --git a/client/stmt.go b/client/stmt.go index b9866704d..b831050e6 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -7,7 +7,7 @@ import ( "math" "runtime" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" ) @@ -33,7 +33,7 @@ func (s *Stmt) WarningsNum() int { return s.warnings } -func (s *Stmt) Execute(args ...interface{}) (*Result, error) { +func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) { if err := s.write(args...); err != nil { return nil, errors.Trace(err) } @@ -41,7 +41,7 @@ func (s *Stmt) Execute(args ...interface{}) (*Result, error) { return s.conn.readResult(true) } -func (s *Stmt) ExecuteSelectStreaming(result *Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback, args ...interface{}) error { +func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRowCallback, perResCb SelectPerResultCallback, args ...interface{}) error { if err := s.write(args...); err != nil { return errors.Trace(err) } @@ -50,7 +50,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *Result, perRowCb SelectPerRowCallb } func (s *Stmt) Close() error { - if err := s.conn.writeCommandUint32(COM_STMT_CLOSE, s.id); err != nil { + if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.id); err != nil { return errors.Trace(err) } @@ -66,10 +66,10 @@ func (s *Stmt) write(args ...interface{}) error { return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) } - if (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { + if (s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { _, file, line, ok := runtime.Caller(s.conn.includeLine) if ok { - lineAttr := QueryAttribute{ + lineAttr := mysql.QueryAttribute{ Name: "_line", Value: fmt.Sprintf("%s:%d", file, line), } @@ -93,7 +93,7 @@ func (s *Stmt) write(args ...interface{}) error { for i := range args { if args[i] == nil { nullBitmap[i/8] |= 1 << (uint(i) % 8) - paramTypes[i] = []byte{MYSQL_TYPE_NULL} + paramTypes[i] = []byte{mysql.MYSQL_TYPE_NULL} paramNames[i] = []byte{0} // length encoded, no name paramFlags[i] = []byte{0} continue @@ -103,62 +103,62 @@ func (s *Stmt) write(args ...interface{}) error { switch v := args[i].(type) { case int8: - paramTypes[i] = []byte{MYSQL_TYPE_TINY} + paramTypes[i] = []byte{mysql.MYSQL_TYPE_TINY} paramValues[i] = []byte{byte(v)} case int16: - paramTypes[i] = []byte{MYSQL_TYPE_SHORT} - paramValues[i] = Uint16ToBytes(uint16(v)) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_SHORT} + paramValues[i] = mysql.Uint16ToBytes(uint16(v)) case int32: - paramTypes[i] = []byte{MYSQL_TYPE_LONG} - paramValues[i] = Uint32ToBytes(uint32(v)) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_LONG} + paramValues[i] = mysql.Uint32ToBytes(uint32(v)) case int: - paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramValues[i] = Uint64ToBytes(uint64(v)) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_LONGLONG} + paramValues[i] = mysql.Uint64ToBytes(uint64(v)) case int64: - paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramValues[i] = Uint64ToBytes(uint64(v)) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_LONGLONG} + paramValues[i] = mysql.Uint64ToBytes(uint64(v)) case uint8: - paramTypes[i] = []byte{MYSQL_TYPE_TINY} - paramFlags[i] = []byte{PARAM_UNSIGNED} + paramTypes[i] = []byte{mysql.MYSQL_TYPE_TINY} + paramFlags[i] = []byte{mysql.PARAM_UNSIGNED} paramValues[i] = []byte{v} case uint16: - paramTypes[i] = []byte{MYSQL_TYPE_SHORT} - paramFlags[i] = []byte{PARAM_UNSIGNED} - paramValues[i] = Uint16ToBytes(v) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_SHORT} + paramFlags[i] = []byte{mysql.PARAM_UNSIGNED} + paramValues[i] = mysql.Uint16ToBytes(v) case uint32: - paramTypes[i] = []byte{MYSQL_TYPE_LONG} - paramFlags[i] = []byte{PARAM_UNSIGNED} - paramValues[i] = Uint32ToBytes(v) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_LONG} + paramFlags[i] = []byte{mysql.PARAM_UNSIGNED} + paramValues[i] = mysql.Uint32ToBytes(v) case uint: - paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramFlags[i] = []byte{PARAM_UNSIGNED} - paramValues[i] = Uint64ToBytes(uint64(v)) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_LONGLONG} + paramFlags[i] = []byte{mysql.PARAM_UNSIGNED} + paramValues[i] = mysql.Uint64ToBytes(uint64(v)) case uint64: - paramTypes[i] = []byte{MYSQL_TYPE_LONGLONG} - paramFlags[i] = []byte{PARAM_UNSIGNED} - paramValues[i] = Uint64ToBytes(v) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_LONGLONG} + paramFlags[i] = []byte{mysql.PARAM_UNSIGNED} + paramValues[i] = mysql.Uint64ToBytes(v) case bool: - paramTypes[i] = []byte{MYSQL_TYPE_TINY} + paramTypes[i] = []byte{mysql.MYSQL_TYPE_TINY} if v { paramValues[i] = []byte{1} } else { paramValues[i] = []byte{0} } case float32: - paramTypes[i] = []byte{MYSQL_TYPE_FLOAT} - paramValues[i] = Uint32ToBytes(math.Float32bits(v)) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_FLOAT} + paramValues[i] = mysql.Uint32ToBytes(math.Float32bits(v)) case float64: - paramTypes[i] = []byte{MYSQL_TYPE_DOUBLE} - paramValues[i] = Uint64ToBytes(math.Float64bits(v)) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_DOUBLE} + paramValues[i] = mysql.Uint64ToBytes(math.Float64bits(v)) case string: - paramTypes[i] = []byte{MYSQL_TYPE_STRING} - paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_STRING} + paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...) case []byte: - paramTypes[i] = []byte{MYSQL_TYPE_STRING} - paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_STRING} + paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...) case json.RawMessage: - paramTypes[i] = []byte{MYSQL_TYPE_STRING} - paramValues[i] = append(PutLengthEncodedInt(uint64(len(v))), v...) + paramTypes[i] = []byte{mysql.MYSQL_TYPE_STRING} + paramValues[i] = append(mysql.PutLengthEncodedInt(uint64(len(v))), v...) default: return fmt.Errorf("invalid argument type %T", args[i]) } @@ -174,7 +174,7 @@ func (s *Stmt) write(args ...interface{}) error { paramTypes[(i + paramsNum)] = []byte{tf[0]} paramFlags[i+paramsNum] = []byte{tf[1]} paramValues[i+paramsNum] = qa.ValueBytes() - paramNames[i+paramsNum] = PutLengthEncodedString([]byte(qa.Name)) + paramNames[i+paramsNum] = mysql.PutLengthEncodedString([]byte(qa.Name)) } data := utils.BytesBufferGet() @@ -186,22 +186,22 @@ func (s *Stmt) write(args ...interface{}) error { } data.Write([]byte{0, 0, 0, 0}) - data.WriteByte(COM_STMT_EXECUTE) + data.WriteByte(mysql.COM_STMT_EXECUTE) data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) - flags := CURSOR_TYPE_NO_CURSOR + flags := mysql.CURSOR_TYPE_NO_CURSOR if paramsNum > 0 { - flags |= PARAMETER_COUNT_AVAILABLE + flags |= mysql.PARAMETER_COUNT_AVAILABLE } data.WriteByte(flags) //iteration-count, always 1 data.Write([]byte{1, 0, 0, 0}) - if paramsNum > 0 || (s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 && (flags&PARAMETER_COUNT_AVAILABLE > 0)) { - if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + if paramsNum > 0 || (s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0 && (flags&mysql.PARAMETER_COUNT_AVAILABLE > 0)) { + if s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0 { paramsNum += len(s.conn.queryAttributes) - data.Write(PutLengthEncodedInt(uint64(paramsNum))) + data.Write(mysql.PutLengthEncodedInt(uint64(paramsNum))) } if paramsNum > 0 { data.Write(nullBitmap) @@ -214,7 +214,7 @@ func (s *Stmt) write(args ...interface{}) error { data.Write(paramTypes[i]) data.Write(paramFlags[i]) - if s.conn.capability&CLIENT_QUERY_ATTRIBUTES > 0 { + if s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0 { data.Write(paramNames[i]) } } @@ -233,7 +233,7 @@ func (s *Stmt) write(args ...interface{}) error { } func (c *Conn) Prepare(query string) (*Stmt, error) { - if err := c.writeCommandStr(COM_STMT_PREPARE, query); err != nil { + if err := c.writeCommandStr(mysql.COM_STMT_PREPARE, query); err != nil { return nil, errors.Trace(err) } @@ -242,10 +242,10 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { return nil, errors.Trace(err) } - if data[0] == ERR_HEADER { + if data[0] == mysql.ERR_HEADER { return nil, c.handleErrorPacket(data) - } else if data[0] != OK_HEADER { - return nil, ErrMalformPacket + } else if data[0] != mysql.OK_HEADER { + return nil, mysql.ErrMalformPacket } s := new(Stmt) diff --git a/dump/dumper.go b/dump/dumper.go index 91179a0c1..537bedc40 100644 --- a/dump/dumper.go +++ b/dump/dumper.go @@ -10,7 +10,7 @@ import ( "regexp" "strings" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" "github.com/siddontang/go-log/log" "github.com/siddontang/go-log/loggers" @@ -81,7 +81,7 @@ func NewDumper(executionPath string, addr string, user string, password string) d.Password = password d.Tables = make([]string, 0, 16) d.Databases = make([]string, 0, 16) - d.Charset = DEFAULT_CHARSET + d.Charset = mysql.DEFAULT_CHARSET d.IgnoreTables = make(map[string][]string) d.ExtraOptions = make([]string, 0, 5) d.masterDataSkipped = false diff --git a/packet/conn.go b/packet/conn.go index 1d49848f3..964e145ae 100644 --- a/packet/conn.go +++ b/packet/conn.go @@ -14,7 +14,7 @@ import ( "time" "github.com/go-mysql-org/go-mysql/compress" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" "github.com/klauspost/compress/zstd" "github.com/pingcap/errors" @@ -104,7 +104,7 @@ func (c *Conn) ReadPacketReuseMem(dst []byte) ([]byte, error) { utils.BytesBufferPut(buf) }() - if c.Compression != MYSQL_COMPRESS_NONE { + if c.Compression != mysql.MYSQL_COMPRESS_NONE { // it's possible that we're using compression but the server response with a compressed // packet with uncompressed length of 0. In this case we leave compressedReader nil. The // compressedReaderActive flag is important to track the state of the reader, allowing @@ -155,7 +155,7 @@ func (c *Conn) newCompressedPacketReader() (io.Reader, error) { } } if _, err := io.ReadFull(c.reader, c.compressedHeader[:7]); err != nil { - return nil, errors.Wrapf(ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) + return nil, errors.Wrapf(mysql.ErrBadConn, "io.ReadFull(compressedHeader) failed. err %v", err) } compressedSequence := c.compressedHeader[3] @@ -169,9 +169,9 @@ func (c *Conn) newCompressedPacketReader() (io.Reader, error) { if uncompressedLength > 0 { limitedReader := io.LimitReader(c.reader, int64(compressedLength)) switch c.Compression { - case MYSQL_COMPRESS_ZLIB: + case mysql.MYSQL_COMPRESS_ZLIB: return compress.GetPooledZlibReader(limitedReader) - case MYSQL_COMPRESS_ZSTD: + case mysql.MYSQL_COMPRESS_ZSTD: return zstd.NewReader(limitedReader) } } @@ -180,7 +180,7 @@ func (c *Conn) newCompressedPacketReader() (io.Reader, error) { } func (c *Conn) currentPacketReader() io.Reader { - if c.Compression == MYSQL_COMPRESS_NONE || c.compressedReader == nil { + if c.Compression == mysql.MYSQL_COMPRESS_NONE || c.compressedReader == nil { return c.reader } else { return c.compressedReader @@ -212,7 +212,7 @@ func (c *Conn) copyN(dst io.Writer, n int64) (int64, error) { // bytes are read. In this case, and when we have compression then advance // the sequence number and reset the compressed reader to continue reading // the remaining bytes in the next compressed packet. - if c.Compression != MYSQL_COMPRESS_NONE && + if c.Compression != mysql.MYSQL_COMPRESS_NONE && (goErrors.Is(err, io.ErrUnexpectedEOF) || goErrors.Is(err, io.EOF)) { // we have read to EOF and read an incomplete uncompressed packet // so advance the compressed sequence number and reset the compressed reader @@ -249,7 +249,7 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { // buffer, since copyN is capable of getting the next compressed // packet and updating the Conn state with a new compressedReader. if _, err := c.copyN(b, 4); err != nil { - return errors.Wrapf(ErrBadConn, "io.ReadFull(header) failed. err %v", err) + return errors.Wrapf(mysql.ErrBadConn, "io.ReadFull(header) failed. err %v", err) } else { // copy was successful so copy the 4 bytes from the buffer to the header copy(c.header[:4], b.Bytes()[:4]) @@ -270,11 +270,11 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { } if n, err := c.copyN(w, int64(length)); err != nil { - return errors.Wrapf(ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) + return errors.Wrapf(mysql.ErrBadConn, "io.CopyN failed. err %v, copied %v, expected %v", err, n, length) } else if n != int64(length) { - return errors.Wrapf(ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) + return errors.Wrapf(mysql.ErrBadConn, "io.CopyN failed(n != int64(length)). %v bytes copied, while %v expected", n, length) } else { - if length < MaxPayloadLen { + if length < mysql.MaxPayloadLen { return nil } @@ -290,21 +290,23 @@ func (c *Conn) ReadPacketTo(w io.Writer) error { func (c *Conn) WritePacket(data []byte) error { length := len(data) - 4 - for length >= MaxPayloadLen { + for length >= mysql.MaxPayloadLen { data[0] = 0xff data[1] = 0xff data[2] = 0xff data[3] = c.Sequence - if n, err := c.writeWithTimeout(data[:4+MaxPayloadLen]); err != nil { - return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. err %v", err) - } else if n != (4 + MaxPayloadLen) { - return errors.Wrapf(ErrBadConn, "Write(payload portion) failed. only %v bytes written, while %v expected", n, 4+MaxPayloadLen) + if n, err := c.writeWithTimeout(data[:4+mysql.MaxPayloadLen]); err != nil { + return errors.Wrapf(mysql.ErrBadConn, + "Write(payload portion) failed. err %v", err) + } else if n != (4 + mysql.MaxPayloadLen) { + return errors.Wrapf(mysql.ErrBadConn, + "Write(payload portion) failed. only %v bytes written, while %v expected", n, 4+mysql.MaxPayloadLen) } else { c.Sequence++ - length -= MaxPayloadLen - data = data[MaxPayloadLen:] + length -= mysql.MaxPayloadLen + data = data[mysql.MaxPayloadLen:] } } @@ -314,17 +316,17 @@ func (c *Conn) WritePacket(data []byte) error { data[3] = c.Sequence switch c.Compression { - case MYSQL_COMPRESS_NONE: + case mysql.MYSQL_COMPRESS_NONE: if n, err := c.writeWithTimeout(data); err != nil { - return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) + return errors.Wrapf(mysql.ErrBadConn, "Write failed. err %v", err) } else if n != len(data) { - return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + return errors.Wrapf(mysql.ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) } - case MYSQL_COMPRESS_ZLIB, MYSQL_COMPRESS_ZSTD: + case mysql.MYSQL_COMPRESS_ZLIB, mysql.MYSQL_COMPRESS_ZSTD: if n, err := c.writeCompressed(data); err != nil { - return errors.Wrapf(ErrBadConn, "Write failed. err %v", err) + return errors.Wrapf(mysql.ErrBadConn, "Write failed. err %v", err) } else if n != len(data) { - return errors.Wrapf(ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) + return errors.Wrapf(mysql.ErrBadConn, "Write failed. only %v bytes written, while %v expected", n, len(data)) } c.compressedReaderActive = false @@ -335,7 +337,7 @@ func (c *Conn) WritePacket(data []byte) error { c.compressedReader = nil } default: - return errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set") + return errors.Wrapf(mysql.ErrBadConn, "Write failed. Unsuppored compression algorithm set") } c.Sequence++ @@ -365,12 +367,12 @@ func (c *Conn) writeCompressed(data []byte) (n int, err error) { defer utils.BytesBufferPut(payload) switch c.Compression { - case MYSQL_COMPRESS_ZLIB: + case mysql.MYSQL_COMPRESS_ZLIB: w, err = compress.GetPooledZlibWriter(payload) - case MYSQL_COMPRESS_ZSTD: + case mysql.MYSQL_COMPRESS_ZSTD: w, err = zstd.NewWriter(payload) default: - return 0, errors.Wrapf(ErrBadConn, "Write failed. Unsuppored compression algorithm set") + return 0, errors.Wrapf(mysql.ErrBadConn, "Write failed. Unsuppored compression algorithm set") } if err != nil { return 0, err @@ -472,7 +474,7 @@ func (c *Conn) WritePublicKeyAuthPacket(password string, cipher []byte) error { } func (c *Conn) WriteEncryptedPassword(password string, seed []byte, pub *rsa.PublicKey) error { - enc, err := EncryptPassword(password, seed, pub) + enc, err := mysql.EncryptPassword(password, seed, pub) if err != nil { return errors.Wrap(err, "EncryptPassword failed") } diff --git a/replication/backup.go b/replication/backup.go index f1d7dd28b..681c0e21f 100644 --- a/replication/backup.go +++ b/replication/backup.go @@ -8,12 +8,12 @@ import ( "sync" "time" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" ) // StartBackup starts the backup process for the binary log and writes to the backup directory. -func (b *BinlogSyncer) StartBackup(backupDir string, p Position, timeout time.Duration) error { +func (b *BinlogSyncer) StartBackup(backupDir string, p mysql.Position, timeout time.Duration) error { err := os.MkdirAll(backupDir, 0755) if err != nil { return errors.Trace(err) @@ -36,7 +36,7 @@ func (b *BinlogSyncer) StartBackup(backupDir string, p Position, timeout time.Du // - timeout: The maximum duration to wait for new binlog events before stopping the backup process. // If set to 0, a default very long timeout (30 days) is used instead. // - handler: A function that takes a binlog filename and returns an WriteCloser for writing raw events to. -func (b *BinlogSyncer) StartBackupWithHandler(p Position, timeout time.Duration, +func (b *BinlogSyncer) StartBackupWithHandler(p mysql.Position, timeout time.Duration, handler func(binlogFilename string) (io.WriteCloser, error)) (retErr error) { if timeout == 0 { // a very long timeout here @@ -89,7 +89,7 @@ func (b *BinlogSyncer) StartBackupWithHandler(p Position, timeout time.Duration, } // StartSynchronousBackup starts the backup process using the SynchronousEventHandler in the BinlogSyncerConfig. -func (b *BinlogSyncer) StartSynchronousBackup(p Position, timeout time.Duration) error { +func (b *BinlogSyncer) StartSynchronousBackup(p mysql.Position, timeout time.Duration) error { if b.cfg.SynchronousEventHandler == nil { return errors.New("SynchronousEventHandler must be set in BinlogSyncerConfig to use StartSynchronousBackup") } diff --git a/replication/binlogsyncer.go b/replication/binlogsyncer.go index 41a02e8d0..5105a3dfd 100644 --- a/replication/binlogsyncer.go +++ b/replication/binlogsyncer.go @@ -18,7 +18,7 @@ import ( "github.com/siddontang/go-log/loggers" "github.com/go-mysql-org/go-mysql/client" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" ) @@ -151,9 +151,9 @@ type BinlogSyncer struct { parser *BinlogParser - nextPos Position + nextPos mysql.Position - prevGset, currGset GTIDSet + prevGset, currGset mysql.GTIDSet // instead of GTIDSet.Clone, use this to speed up calculate prevGset prevMySQLGTIDEvent *GTIDEvent @@ -331,7 +331,7 @@ func (b *BinlogSyncer) registerSlave() error { } } - if b.cfg.Flavor == MariaDBFlavor { + if b.cfg.Flavor == mysql.MariaDBFlavor { // Refer https://github.com/alibaba/canal/wiki/BinlogChange(MariaDB5&10) // Tell the server that we understand GTIDs by setting our slave capability // to MARIA_SLAVE_CAPABILITY_GTID = 4 (MariaDB >= 10.0.1). @@ -422,13 +422,13 @@ func (b *BinlogSyncer) startDumpStream() *BinlogStreamer { } // GetNextPosition returns the next position of the syncer -func (b *BinlogSyncer) GetNextPosition() Position { +func (b *BinlogSyncer) GetNextPosition() mysql.Position { return b.nextPos } func (b *BinlogSyncer) checkFlavor() { serverVersion := b.c.GetServerVersion() - if b.cfg.Flavor != MariaDBFlavor && + if b.cfg.Flavor != mysql.MariaDBFlavor && strings.Contains(b.c.GetServerVersion(), "MariaDB") { // Setting the flavor to `mysql` causes MariaDB to try and behave // in a MySQL compatible way. In this mode MariaDB won't use @@ -439,7 +439,7 @@ func (b *BinlogSyncer) checkFlavor() { } // StartSync starts syncing from the `pos` position. -func (b *BinlogSyncer) StartSync(pos Position) (*BinlogStreamer, error) { +func (b *BinlogSyncer) StartSync(pos mysql.Position) (*BinlogStreamer, error) { b.cfg.Logger.Infof("begin to sync binlog from position %s", pos) b.m.Lock() @@ -459,7 +459,7 @@ func (b *BinlogSyncer) StartSync(pos Position) (*BinlogStreamer, error) { } // StartSyncGTID starts syncing from the `gset` GTIDSet. -func (b *BinlogSyncer) StartSyncGTID(gset GTIDSet) (*BinlogStreamer, error) { +func (b *BinlogSyncer) StartSyncGTID(gset mysql.GTIDSet) (*BinlogStreamer, error) { b.cfg.Logger.Infof("begin to sync binlog from GTID set %s", gset) b.prevMySQLGTIDEvent = nil @@ -482,7 +482,7 @@ func (b *BinlogSyncer) StartSyncGTID(gset GTIDSet) (*BinlogStreamer, error) { var err error switch b.cfg.Flavor { - case MariaDBFlavor: + case mysql.MariaDBFlavor: err = b.writeBinlogDumpMariadbGTIDCommand(gset) default: // default use MySQL @@ -498,13 +498,13 @@ func (b *BinlogSyncer) StartSyncGTID(gset GTIDSet) (*BinlogStreamer, error) { return b.startDumpStream(), nil } -func (b *BinlogSyncer) writeBinlogDumpCommand(p Position) error { +func (b *BinlogSyncer) writeBinlogDumpCommand(p mysql.Position) error { b.c.ResetSequence() data := make([]byte, 4+1+4+2+4+len(p.Name)) pos := 4 - data[pos] = COM_BINLOG_DUMP + data[pos] = mysql.COM_BINLOG_DUMP pos++ binary.LittleEndian.PutUint32(data[pos:], p.Pos) @@ -521,15 +521,15 @@ func (b *BinlogSyncer) writeBinlogDumpCommand(p Position) error { return b.c.WritePacket(data) } -func (b *BinlogSyncer) writeBinlogDumpMysqlGTIDCommand(gset GTIDSet) error { - p := Position{Name: "", Pos: 4} +func (b *BinlogSyncer) writeBinlogDumpMysqlGTIDCommand(gset mysql.GTIDSet) error { + p := mysql.Position{Name: "", Pos: 4} gtidData := gset.Encode() b.c.ResetSequence() data := make([]byte, 4+1+2+4+4+len(p.Name)+8+4+len(gtidData)) pos := 4 - data[pos] = COM_BINLOG_DUMP_GTID + data[pos] = mysql.COM_BINLOG_DUMP_GTID pos++ binary.LittleEndian.PutUint16(data[pos:], 0) @@ -557,7 +557,7 @@ func (b *BinlogSyncer) writeBinlogDumpMysqlGTIDCommand(gset GTIDSet) error { return b.c.WritePacket(data) } -func (b *BinlogSyncer) writeBinlogDumpMariadbGTIDCommand(gset GTIDSet) error { +func (b *BinlogSyncer) writeBinlogDumpMariadbGTIDCommand(gset mysql.GTIDSet) error { // Copy from vitess startPos := gset.String() @@ -578,7 +578,7 @@ func (b *BinlogSyncer) writeBinlogDumpMariadbGTIDCommand(gset GTIDSet) error { } // Since we use @slave_connect_state, the file and position here are ignored. - return b.writeBinlogDumpCommand(Position{Name: "", Pos: 0}) + return b.writeBinlogDumpCommand(mysql.Position{Name: "", Pos: 0}) } // localHostname returns the hostname that register slave would register as. @@ -599,7 +599,7 @@ func (b *BinlogSyncer) writeRegisterSlaveCommand() error { data := make([]byte, 4+1+4+1+len(hostname)+1+len(b.cfg.User)+1+2+4+4) pos := 4 - data[pos] = COM_REGISTER_SLAVE + data[pos] = mysql.COM_REGISTER_SLAVE pos++ binary.LittleEndian.PutUint32(data[pos:], b.cfg.ServerID) @@ -632,7 +632,7 @@ func (b *BinlogSyncer) writeRegisterSlaveCommand() error { return b.c.WritePacket(data) } -func (b *BinlogSyncer) replySemiSyncACK(p Position) error { +func (b *BinlogSyncer) replySemiSyncACK(p mysql.Position) error { b.c.ResetSequence() data := make([]byte, 4+1+8+len(p.Name)) @@ -681,7 +681,7 @@ func (b *BinlogSyncer) retrySync() error { return nil } -func (b *BinlogSyncer) prepareSyncPos(pos Position) error { +func (b *BinlogSyncer) prepareSyncPos(pos mysql.Position) error { // always start from position 4 if pos.Pos < 4 { pos.Pos = 4 @@ -698,7 +698,7 @@ func (b *BinlogSyncer) prepareSyncPos(pos Position) error { return nil } -func (b *BinlogSyncer) prepareSyncGTID(gset GTIDSet) error { +func (b *BinlogSyncer) prepareSyncGTID(gset mysql.GTIDSet) error { var err error // re establishing network connection here and will start getting binlog events from "gset + 1", thus until first @@ -710,7 +710,7 @@ func (b *BinlogSyncer) prepareSyncGTID(gset GTIDSet) error { } switch b.cfg.Flavor { - case MariaDBFlavor: + case mysql.MariaDBFlavor: err = b.writeBinlogDumpMariadbGTIDCommand(gset) default: // default use MySQL @@ -726,7 +726,7 @@ func (b *BinlogSyncer) prepareSyncGTID(gset GTIDSet) error { func (b *BinlogSyncer) onStream(s *BinlogStreamer) { defer func() { if e := recover(); e != nil { - s.closeWithError(fmt.Errorf("Err: %v\n Stack: %s", e, Pstack())) + s.closeWithError(fmt.Errorf("Err: %v\n Stack: %s", e, mysql.Pstack())) } b.wg.Done() }() @@ -797,7 +797,7 @@ func (b *BinlogSyncer) onStream(s *BinlogStreamer) { b.retryCount = 0 switch data[0] { - case OK_HEADER: + case mysql.OK_HEADER: // Parse the event e, needACK, err := b.parseEvent(data) if err != nil { @@ -811,11 +811,11 @@ func (b *BinlogSyncer) onStream(s *BinlogStreamer) { s.closeWithError(err) return } - case ERR_HEADER: + case mysql.ERR_HEADER: err = b.c.HandleErrorPacket(data) s.closeWithError(err) return - case EOF_HEADER: + case mysql.EOF_HEADER: // refer to https://dev.mysql.com/doc/internals/en/com-binlog-dump.html#binlog-dump-non-block // when COM_BINLOG_DUMP command use BINLOG_DUMP_NON_BLOCK flag, // if there is no more event to send an EOF_Packet instead of blocking the connection @@ -878,13 +878,13 @@ func (b *BinlogSyncer) handleEventAndACK(s *BinlogStreamer, e *BinlogEvent, need if err != nil { return errors.Trace(err) } - b.currGset.(*MysqlGTIDSet).AddGTID(u, event.GNO) + b.currGset.(*mysql.MysqlGTIDSet).AddGTID(u, event.GNO) if b.prevMySQLGTIDEvent != nil { u, err = uuid.FromBytes(b.prevMySQLGTIDEvent.SID) if err != nil { return errors.Trace(err) } - b.prevGset.(*MysqlGTIDSet).AddGTID(u, b.prevMySQLGTIDEvent.GNO) + b.prevGset.(*mysql.MysqlGTIDSet).AddGTID(u, b.prevMySQLGTIDEvent.GNO) } b.prevMySQLGTIDEvent = event @@ -896,7 +896,7 @@ func (b *BinlogSyncer) handleEventAndACK(s *BinlogStreamer, e *BinlogEvent, need b.currGset = b.prevGset.Clone() } prev := b.currGset.Clone() - err := b.currGset.(*MariadbGTIDSet).AddSet(&event.GTID) + err := b.currGset.(*mysql.MariadbGTIDSet).AddSet(&event.GTID) if err != nil { return errors.Trace(err) } @@ -942,7 +942,7 @@ func (b *BinlogSyncer) handleEventAndACK(s *BinlogStreamer, e *BinlogEvent, need } // getCurrentGtidSet returns a clone of the current GTID set. -func (b *BinlogSyncer) getCurrentGtidSet() GTIDSet { +func (b *BinlogSyncer) getCurrentGtidSet() mysql.GTIDSet { if b.currGset != nil { return b.currGset.Clone() } @@ -981,7 +981,7 @@ func (b *BinlogSyncer) killConnection(conn *client.Conn, id uint32) { if _, err := conn.Execute(cmd); err != nil { b.cfg.Logger.Errorf("kill connection %d error %v", id, err) // Unknown thread id - if code := ErrorCode(err.Error()); code != ER_NO_SUCH_THREAD { + if code := mysql.ErrorCode(err.Error()); code != mysql.ER_NO_SUCH_THREAD { b.cfg.Logger.Error(errors.Trace(err)) } } diff --git a/replication/event.go b/replication/event.go index c151f32ba..d16aa574c 100644 --- a/replication/event.go +++ b/replication/event.go @@ -11,10 +11,9 @@ import ( "time" "unicode" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/google/uuid" "github.com/pingcap/errors" - - . "github.com/go-mysql-org/go-mysql/mysql" ) const ( @@ -103,7 +102,7 @@ func (h *EventHeader) Decode(data []byte) error { func (h *EventHeader) Dump(w io.Writer) { fmt.Fprintf(w, "=== %s ===\n", h.EventType) - fmt.Fprintf(w, "Date: %s\n", time.Unix(int64(h.Timestamp), 0).Format(TimeFormat)) + fmt.Fprintf(w, "Date: %s\n", time.Unix(int64(h.Timestamp), 0).Format(mysql.TimeFormat)) fmt.Fprintf(w, "Log position: %d\n", h.LogPos) fmt.Fprintf(w, "Event size: %d\n", h.EventSize) } @@ -335,7 +334,7 @@ type XIDEvent struct { XID uint64 // in fact XIDEvent dosen't have the GTIDSet information, just for beneficial to use - GSet GTIDSet + GSet mysql.GTIDSet } func (e *XIDEvent) Decode(data []byte) error { @@ -363,7 +362,7 @@ type QueryEvent struct { compressed bool // in fact QueryEvent dosen't have the GTIDSet information, just for beneficial to use - GSet GTIDSet + GSet mysql.GTIDSet } func (e *QueryEvent) Decode(data []byte) error { @@ -394,7 +393,7 @@ func (e *QueryEvent) Decode(data []byte) error { pos++ if e.compressed { - decompressedQuery, err := DecompressMariadbData(data[pos:]) + decompressedQuery, err := mysql.DecompressMariadbData(data[pos:]) if err != nil { return err } @@ -461,12 +460,12 @@ func (e *GTIDEvent) Decode(data []byte) error { if len(data)-pos < 7 { return nil } - e.ImmediateCommitTimestamp = FixedLengthInt(data[pos : pos+7]) + e.ImmediateCommitTimestamp = mysql.FixedLengthInt(data[pos : pos+7]) pos += 7 if (e.ImmediateCommitTimestamp & (uint64(1) << 55)) != 0 { // If the most significant bit set, another 7 byte follows representing OriginalCommitTimestamp e.ImmediateCommitTimestamp &= ^(uint64(1) << 55) - e.OriginalCommitTimestamp = FixedLengthInt(data[pos : pos+7]) + e.OriginalCommitTimestamp = mysql.FixedLengthInt(data[pos : pos+7]) pos += 7 } else { // Otherwise OriginalCommitTimestamp == ImmediateCommitTimestamp @@ -478,7 +477,7 @@ func (e *GTIDEvent) Decode(data []byte) error { return nil } var n int - e.TransactionLength, _, n = LengthEncodedInt(data[pos:]) + e.TransactionLength, _, n = mysql.LengthEncodedInt(data[pos:]) pos += n // IMMEDIATE_SERVER_VERSION_LENGTH = 4 @@ -524,12 +523,12 @@ func (e *GTIDEvent) Dump(w io.Writer) { fmt.Fprintln(w) } -func (e *GTIDEvent) GTIDNext() (GTIDSet, error) { +func (e *GTIDEvent) GTIDNext() (mysql.GTIDSet, error) { u, err := uuid.FromBytes(e.SID) if err != nil { return nil, err } - return ParseMysqlGTIDSet(strings.Join([]string{u.String(), strconv.FormatInt(e.GNO, 10)}, ":")) + return mysql.ParseMysqlGTIDSet(strings.Join([]string{u.String(), strconv.FormatInt(e.GNO, 10)}, ":")) } // ImmediateCommitTime returns the commit time of this trx on the immediate server @@ -655,7 +654,7 @@ func (e *MariadbBinlogCheckPointEvent) Dump(w io.Writer) { } type MariadbGTIDEvent struct { - GTID MariadbGTID + GTID mysql.MariadbGTID Flags byte CommitID uint64 } @@ -695,12 +694,12 @@ func (e *MariadbGTIDEvent) Dump(w io.Writer) { fmt.Fprintln(w) } -func (e *MariadbGTIDEvent) GTIDNext() (GTIDSet, error) { - return ParseMariadbGTIDSet(e.GTID.String()) +func (e *MariadbGTIDEvent) GTIDNext() (mysql.GTIDSet, error) { + return mysql.ParseMariadbGTIDSet(e.GTID.String()) } type MariadbGTIDListEvent struct { - GTIDs []MariadbGTID + GTIDs []mysql.MariadbGTID } func (e *MariadbGTIDListEvent) Decode(data []byte) error { @@ -710,7 +709,7 @@ func (e *MariadbGTIDListEvent) Decode(data []byte) error { count := v & uint32((1<<28)-1) - e.GTIDs = make([]MariadbGTID, count) + e.GTIDs = make([]mysql.MariadbGTID, count) for i := uint32(0); i < count; i++ { e.GTIDs[i].DomainID = binary.LittleEndian.Uint32(data[pos:]) diff --git a/replication/json_binary.go b/replication/json_binary.go index 34186f09e..7c70cb890 100644 --- a/replication/json_binary.go +++ b/replication/json_binary.go @@ -4,11 +4,10 @@ import ( "fmt" "math" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" "github.com/goccy/go-json" "github.com/pingcap/errors" - - . "github.com/go-mysql-org/go-mysql/mysql" ) const ( @@ -340,7 +339,7 @@ func (d *jsonBinaryDecoder) decodeInt16(data []byte) int16 { return 0 } - v := ParseBinaryInt16(data[0:2]) + v := mysql.ParseBinaryInt16(data[0:2]) return v } @@ -349,7 +348,7 @@ func (d *jsonBinaryDecoder) decodeUint16(data []byte) uint16 { return 0 } - v := ParseBinaryUint16(data[0:2]) + v := mysql.ParseBinaryUint16(data[0:2]) return v } @@ -358,7 +357,7 @@ func (d *jsonBinaryDecoder) decodeInt32(data []byte) int32 { return 0 } - v := ParseBinaryInt32(data[0:4]) + v := mysql.ParseBinaryInt32(data[0:4]) return v } @@ -367,7 +366,7 @@ func (d *jsonBinaryDecoder) decodeUint32(data []byte) uint32 { return 0 } - v := ParseBinaryUint32(data[0:4]) + v := mysql.ParseBinaryUint32(data[0:4]) return v } @@ -376,7 +375,7 @@ func (d *jsonBinaryDecoder) decodeInt64(data []byte) int64 { return 0 } - v := ParseBinaryInt64(data[0:8]) + v := mysql.ParseBinaryInt64(data[0:8]) return v } @@ -385,7 +384,7 @@ func (d *jsonBinaryDecoder) decodeUint64(data []byte) uint64 { return 0 } - v := ParseBinaryUint64(data[0:8]) + v := mysql.ParseBinaryUint64(data[0:8]) return v } @@ -394,7 +393,7 @@ func (d *jsonBinaryDecoder) decodeDouble(data []byte) float64 { return 0 } - v := ParseBinaryFloat64(data[0:8]) + v := mysql.ParseBinaryFloat64(data[0:8]) return v } @@ -432,11 +431,11 @@ func (d *jsonBinaryDecoder) decodeOpaque(data []byte) interface{} { data = data[n : l+n] switch tp { - case MYSQL_TYPE_NEWDECIMAL: + case mysql.MYSQL_TYPE_NEWDECIMAL: return d.decodeDecimal(data) - case MYSQL_TYPE_TIME: + case mysql.MYSQL_TYPE_TIME: return d.decodeTime(data) - case MYSQL_TYPE_DATE, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIMESTAMP: + case mysql.MYSQL_TYPE_DATE, mysql.MYSQL_TYPE_DATETIME, mysql.MYSQL_TYPE_TIMESTAMP: return d.decodeDateTime(data) default: return utils.ByteSliceToString(data) @@ -554,7 +553,7 @@ func (e *RowsEvent) decodeJsonPartialBinary(data []byte) (*JsonDiff, error) { } data = data[1:] - pathLength, _, n := LengthEncodedInt(data) + pathLength, _, n := mysql.LengthEncodedInt(data) data = data[n:] path := data[:pathLength] @@ -570,7 +569,7 @@ func (e *RowsEvent) decodeJsonPartialBinary(data []byte) (*JsonDiff, error) { return diff, nil } - valueLength, _, n := LengthEncodedInt(data) + valueLength, _, n := mysql.LengthEncodedInt(data) data = data[n:] d, err := e.decodeJsonBinary(data[:valueLength]) diff --git a/replication/row_event.go b/replication/row_event.go index 0eebe6914..b35d02cbd 100644 --- a/replication/row_event.go +++ b/replication/row_event.go @@ -12,7 +12,7 @@ import ( "github.com/pingcap/errors" "github.com/shopspring/decimal" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/utils" ) @@ -90,7 +90,7 @@ type TableMapEvent struct { func (e *TableMapEvent) Decode(data []byte) error { pos := 0 - e.TableID = FixedLengthInt(data[0:e.tableIDSize]) + e.TableID = mysql.FixedLengthInt(data[0:e.tableIDSize]) pos += e.tableIDSize e.Flags = binary.LittleEndian.Uint16(data[pos:]) @@ -115,7 +115,7 @@ func (e *TableMapEvent) Decode(data []byte) error { pos++ var n int - e.ColumnCount, _, n = LengthEncodedInt(data[pos:]) + e.ColumnCount, _, n = mysql.LengthEncodedInt(data[pos:]) pos += n e.ColumnType = data[pos : pos+int(e.ColumnCount)] @@ -123,7 +123,7 @@ func (e *TableMapEvent) Decode(data []byte) error { var err error var metaData []byte - if metaData, _, n, err = LengthEncodedString(data[pos:]); err != nil { + if metaData, _, n, err = mysql.LengthEncodedString(data[pos:]); err != nil { return errors.Trace(err) } @@ -206,39 +206,39 @@ func (e *TableMapEvent) decodeMeta(data []byte) error { e.ColumnMeta = make([]uint16, e.ColumnCount) for i, t := range e.ColumnType { switch t { - case MYSQL_TYPE_STRING: + case mysql.MYSQL_TYPE_STRING: var x = uint16(data[pos]) << 8 // real type x += uint16(data[pos+1]) // pack or field length e.ColumnMeta[i] = x pos += 2 - case MYSQL_TYPE_NEWDECIMAL: + case mysql.MYSQL_TYPE_NEWDECIMAL: var x = uint16(data[pos]) << 8 // precision x += uint16(data[pos+1]) // decimals e.ColumnMeta[i] = x pos += 2 - case MYSQL_TYPE_VAR_STRING, - MYSQL_TYPE_VARCHAR, - MYSQL_TYPE_BIT: + case mysql.MYSQL_TYPE_VAR_STRING, + mysql.MYSQL_TYPE_VARCHAR, + mysql.MYSQL_TYPE_BIT: e.ColumnMeta[i] = binary.LittleEndian.Uint16(data[pos:]) pos += 2 - case MYSQL_TYPE_BLOB, - MYSQL_TYPE_DOUBLE, - MYSQL_TYPE_FLOAT, - MYSQL_TYPE_GEOMETRY, - MYSQL_TYPE_JSON: + case mysql.MYSQL_TYPE_BLOB, + mysql.MYSQL_TYPE_DOUBLE, + mysql.MYSQL_TYPE_FLOAT, + mysql.MYSQL_TYPE_GEOMETRY, + mysql.MYSQL_TYPE_JSON: e.ColumnMeta[i] = uint16(data[pos]) pos++ - case MYSQL_TYPE_TIME2, - MYSQL_TYPE_DATETIME2, - MYSQL_TYPE_TIMESTAMP2: + case mysql.MYSQL_TYPE_TIME2, + mysql.MYSQL_TYPE_DATETIME2, + mysql.MYSQL_TYPE_TIMESTAMP2: e.ColumnMeta[i] = uint16(data[pos]) pos++ - case MYSQL_TYPE_NEWDATE, - MYSQL_TYPE_ENUM, - MYSQL_TYPE_SET, - MYSQL_TYPE_TINY_BLOB, - MYSQL_TYPE_MEDIUM_BLOB, - MYSQL_TYPE_LONG_BLOB: + case mysql.MYSQL_TYPE_NEWDATE, + mysql.MYSQL_TYPE_ENUM, + mysql.MYSQL_TYPE_SET, + mysql.MYSQL_TYPE_TINY_BLOB, + mysql.MYSQL_TYPE_MEDIUM_BLOB, + mysql.MYSQL_TYPE_LONG_BLOB: return errors.Errorf("unsupport type in binlog %d", t) default: e.ColumnMeta[i] = 0 @@ -256,7 +256,7 @@ func (e *TableMapEvent) decodeOptionalMeta(data []byte) (err error) { t := data[pos] pos++ - l, _, n := LengthEncodedInt(data[pos:]) + l, _, n := mysql.LengthEncodedInt(data[pos:]) pos += n v := data[pos : pos+int(l)] @@ -337,7 +337,7 @@ func (e *TableMapEvent) decodeOptionalMeta(data []byte) (err error) { func (e *TableMapEvent) decodeIntSeq(v []byte) (ret []uint64, err error) { p := 0 for p < len(v) { - i, _, n := LengthEncodedInt(v[p:]) + i, _, n := mysql.LengthEncodedInt(v[p:]) p += n ret = append(ret, i) } @@ -374,11 +374,11 @@ func (e *TableMapEvent) decodeColumnNames(v []byte) error { func (e *TableMapEvent) decodeStrValue(v []byte) (ret [][][]byte, err error) { p := 0 for p < len(v) { - nVal, _, n := LengthEncodedInt(v[p:]) + nVal, _, n := mysql.LengthEncodedInt(v[p:]) p += n vals := make([][]byte, 0, int(nVal)) for i := 0; i < int(nVal); i++ { - val, _, n, err := LengthEncodedString(v[p:]) + val, _, n, err := mysql.LengthEncodedString(v[p:]) if err != nil { return nil, err } @@ -393,7 +393,7 @@ func (e *TableMapEvent) decodeStrValue(v []byte) (ret [][][]byte, err error) { func (e *TableMapEvent) decodeSimplePrimaryKey(v []byte) error { p := 0 for p < len(v) { - i, _, n := LengthEncodedInt(v[p:]) + i, _, n := mysql.LengthEncodedInt(v[p:]) e.PrimaryKey = append(e.PrimaryKey, i) e.PrimaryKeyPrefix = append(e.PrimaryKeyPrefix, 0) p += n @@ -404,10 +404,10 @@ func (e *TableMapEvent) decodeSimplePrimaryKey(v []byte) error { func (e *TableMapEvent) decodePrimaryKeyWithPrefix(v []byte) error { p := 0 for p < len(v) { - i, _, n := LengthEncodedInt(v[p:]) + i, _, n := mysql.LengthEncodedInt(v[p:]) e.PrimaryKey = append(e.PrimaryKey, i) p += n - i, _, n = LengthEncodedInt(v[p:]) + i, _, n = mysql.LengthEncodedInt(v[p:]) e.PrimaryKeyPrefix = append(e.PrimaryKeyPrefix, i) p += n } @@ -776,14 +776,14 @@ func (e *TableMapEvent) realType(i int) byte { typ := e.ColumnType[i] switch typ { - case MYSQL_TYPE_STRING: + case mysql.MYSQL_TYPE_STRING: rtyp := byte(e.ColumnMeta[i] >> 8) - if rtyp == MYSQL_TYPE_ENUM || rtyp == MYSQL_TYPE_SET { + if rtyp == mysql.MYSQL_TYPE_ENUM || rtyp == mysql.MYSQL_TYPE_SET { return rtyp } - case MYSQL_TYPE_DATE: - return MYSQL_TYPE_NEWDATE + case mysql.MYSQL_TYPE_DATE: + return mysql.MYSQL_TYPE_NEWDATE } return typ @@ -791,14 +791,14 @@ func (e *TableMapEvent) realType(i int) byte { func (e *TableMapEvent) IsNumericColumn(i int) bool { switch e.realType(i) { - case MYSQL_TYPE_TINY, - MYSQL_TYPE_SHORT, - MYSQL_TYPE_INT24, - MYSQL_TYPE_LONG, - MYSQL_TYPE_LONGLONG, - MYSQL_TYPE_NEWDECIMAL, - MYSQL_TYPE_FLOAT, - MYSQL_TYPE_DOUBLE: + case mysql.MYSQL_TYPE_TINY, + mysql.MYSQL_TYPE_SHORT, + mysql.MYSQL_TYPE_INT24, + mysql.MYSQL_TYPE_LONG, + mysql.MYSQL_TYPE_LONGLONG, + mysql.MYSQL_TYPE_NEWDECIMAL, + mysql.MYSQL_TYPE_FLOAT, + mysql.MYSQL_TYPE_DOUBLE: return true default: @@ -811,13 +811,13 @@ func (e *TableMapEvent) IsNumericColumn(i int) bool { // (JSON is an alias for LONGTEXT in mariadb: https://mariadb.com/kb/en/json-data-type/) func (e *TableMapEvent) IsCharacterColumn(i int) bool { switch e.realType(i) { - case MYSQL_TYPE_STRING, - MYSQL_TYPE_VAR_STRING, - MYSQL_TYPE_VARCHAR, - MYSQL_TYPE_BLOB: + case mysql.MYSQL_TYPE_STRING, + mysql.MYSQL_TYPE_VAR_STRING, + mysql.MYSQL_TYPE_VARCHAR, + mysql.MYSQL_TYPE_BLOB: return true - case MYSQL_TYPE_GEOMETRY: + case mysql.MYSQL_TYPE_GEOMETRY: if e.flavor == "mariadb" { return true } @@ -829,27 +829,27 @@ func (e *TableMapEvent) IsCharacterColumn(i int) bool { } func (e *TableMapEvent) IsEnumColumn(i int) bool { - return e.realType(i) == MYSQL_TYPE_ENUM + return e.realType(i) == mysql.MYSQL_TYPE_ENUM } func (e *TableMapEvent) IsSetColumn(i int) bool { - return e.realType(i) == MYSQL_TYPE_SET + return e.realType(i) == mysql.MYSQL_TYPE_SET } func (e *TableMapEvent) IsGeometryColumn(i int) bool { - return e.realType(i) == MYSQL_TYPE_GEOMETRY + return e.realType(i) == mysql.MYSQL_TYPE_GEOMETRY } func (e *TableMapEvent) IsEnumOrSetColumn(i int) bool { rtyp := e.realType(i) - return rtyp == MYSQL_TYPE_ENUM || rtyp == MYSQL_TYPE_SET + return rtyp == mysql.MYSQL_TYPE_ENUM || rtyp == mysql.MYSQL_TYPE_SET } // JsonColumnCount returns the number of JSON columns in this table func (e *TableMapEvent) JsonColumnCount() uint64 { count := uint64(0) for _, t := range e.ColumnType { - if t == MYSQL_TYPE_JSON { + if t == mysql.MYSQL_TYPE_JSON { count++ } } @@ -864,32 +864,32 @@ const RowsEventStmtEndFlag = 0x01 // UPDATE_ROWS_EVENT, etc. // RowsEvent.Rows saves the rows data, and the MySQL type to golang type mapping // is -// - MYSQL_TYPE_NULL: nil -// - MYSQL_TYPE_LONG: int32 -// - MYSQL_TYPE_TINY: int8 -// - MYSQL_TYPE_SHORT: int16 -// - MYSQL_TYPE_INT24: int32 -// - MYSQL_TYPE_LONGLONG: int64 -// - MYSQL_TYPE_NEWDECIMAL: string / "github.com/shopspring/decimal".Decimal -// - MYSQL_TYPE_FLOAT: float32 -// - MYSQL_TYPE_DOUBLE: float64 -// - MYSQL_TYPE_BIT: int64 -// - MYSQL_TYPE_TIMESTAMP: string / time.Time -// - MYSQL_TYPE_TIMESTAMP2: string / time.Time -// - MYSQL_TYPE_DATETIME: string / time.Time -// - MYSQL_TYPE_DATETIME2: string / time.Time -// - MYSQL_TYPE_TIME: string -// - MYSQL_TYPE_TIME2: string -// - MYSQL_TYPE_DATE: string -// - MYSQL_TYPE_YEAR: int -// - MYSQL_TYPE_ENUM: int64 -// - MYSQL_TYPE_SET: int64 -// - MYSQL_TYPE_BLOB: []byte -// - MYSQL_TYPE_VARCHAR: string -// - MYSQL_TYPE_VAR_STRING: string -// - MYSQL_TYPE_STRING: string -// - MYSQL_TYPE_JSON: []byte / *replication.JsonDiff -// - MYSQL_TYPE_GEOMETRY: []byte +// - mysql.MYSQL_TYPE_NULL: nil +// - mysql.MYSQL_TYPE_LONG: int32 +// - mysql.MYSQL_TYPE_TINY: int8 +// - mysql.MYSQL_TYPE_SHORT: int16 +// - mysql.MYSQL_TYPE_INT24: int32 +// - mysql.MYSQL_TYPE_LONGLONG: int64 +// - mysql.MYSQL_TYPE_NEWDECIMAL: string / "github.com/shopspring/decimal".Decimal +// - mysql.MYSQL_TYPE_FLOAT: float32 +// - mysql.MYSQL_TYPE_DOUBLE: float64 +// - mysql.MYSQL_TYPE_BIT: int64 +// - mysql.MYSQL_TYPE_TIMESTAMP: string / time.Time +// - mysql.MYSQL_TYPE_TIMESTAMP2: string / time.Time +// - mysql.MYSQL_TYPE_DATETIME: string / time.Time +// - mysql.MYSQL_TYPE_DATETIME2: string / time.Time +// - mysql.MYSQL_TYPE_TIME: string +// - mysql.MYSQL_TYPE_TIME2: string +// - mysql.MYSQL_TYPE_DATE: string +// - mysql.MYSQL_TYPE_YEAR: int +// - mysql.MYSQL_TYPE_ENUM: int64 +// - mysql.MYSQL_TYPE_SET: int64 +// - mysql.MYSQL_TYPE_BLOB: []byte +// - mysql.MYSQL_TYPE_VARCHAR: string +// - mysql.MYSQL_TYPE_VAR_STRING: string +// - mysql.MYSQL_TYPE_STRING: string +// - mysql.MYSQL_TYPE_JSON: []byte / *replication.JsonDiff +// - mysql.MYSQL_TYPE_GEOMETRY: []byte type RowsEvent struct { // 0, 1, 2 Version int @@ -983,7 +983,7 @@ const ( func (e *RowsEvent) DecodeHeader(data []byte) (int, error) { pos := 0 - e.TableID = FixedLengthInt(data[0:e.tableIDSize]) + e.TableID = mysql.FixedLengthInt(data[0:e.tableIDSize]) pos += e.tableIDSize e.Flags = binary.LittleEndian.Uint16(data[pos:]) @@ -1002,7 +1002,7 @@ func (e *RowsEvent) DecodeHeader(data []byte) (int, error) { } var n int - e.ColumnCount, _, n = LengthEncodedInt(data[pos:]) + e.ColumnCount, _, n = mysql.LengthEncodedInt(data[pos:]) pos += n bitCount := bitmapByteSize(int(e.ColumnCount)) @@ -1051,7 +1051,7 @@ func (e *RowsEvent) decodeExtraData(data []byte) (err2 error) { func (e *RowsEvent) DecodeData(pos int, data []byte) (err2 error) { if e.compressed { - data, err2 = DecompressMariadbData(data[pos:]) + data, err2 = mysql.DecompressMariadbData(data[pos:]) if err2 != nil { //nolint:nakedret return @@ -1135,7 +1135,7 @@ func (e *RowsEvent) decodeImage(data []byte, bitmap []byte, rowImageType EnumRow var partialBitmap []byte if e.eventType == PARTIAL_UPDATE_ROWS_EVENT && rowImageType == EnumRowImageTypeUpdateAI { - binlogRowValueOptions, _, n := LengthEncodedInt(data[pos:]) // binlog_row_value_options + binlogRowValueOptions, _, n := mysql.LengthEncodedInt(data[pos:]) // binlog_row_value_options pos += n isPartialJsonUpdate = EnumBinlogRowValueOptions(binlogRowValueOptions)&EnumBinlogRowValueOptionsPartialJsonUpdates != 0 if isPartialJsonUpdate { @@ -1171,7 +1171,7 @@ func (e *RowsEvent) decodeImage(data []byte, bitmap []byte, rowImageType EnumRow */ isPartial := isPartialJsonUpdate && (rowImageType == EnumRowImageTypeUpdateAI) && - (e.Table.ColumnType[i] == MYSQL_TYPE_JSON) && + (e.Table.ColumnType[i] == mysql.MYSQL_TYPE_JSON) && isBitSetIncr(partialBitmap, &partialBitmapIndex) if !isBitSet(bitmap, i) { @@ -1218,7 +1218,7 @@ func (e *RowsEvent) parseFracTime(t interface{}) interface{} { func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16, isPartial bool) (v interface{}, n int, err error) { var length = 0 - if tp == MYSQL_TYPE_STRING { + if tp == mysql.MYSQL_TYPE_STRING { if meta >= 256 { b0 := uint8(meta >> 8) b1 := uint8(meta & 0xFF) @@ -1236,40 +1236,40 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16, isPartial boo } switch tp { - case MYSQL_TYPE_NULL: + case mysql.MYSQL_TYPE_NULL: return nil, 0, nil - case MYSQL_TYPE_LONG: + case mysql.MYSQL_TYPE_LONG: n = 4 - v = ParseBinaryInt32(data) - case MYSQL_TYPE_TINY: + v = mysql.ParseBinaryInt32(data) + case mysql.MYSQL_TYPE_TINY: n = 1 - v = ParseBinaryInt8(data) - case MYSQL_TYPE_SHORT: + v = mysql.ParseBinaryInt8(data) + case mysql.MYSQL_TYPE_SHORT: n = 2 - v = ParseBinaryInt16(data) - case MYSQL_TYPE_INT24: + v = mysql.ParseBinaryInt16(data) + case mysql.MYSQL_TYPE_INT24: n = 3 - v = ParseBinaryInt24(data) - case MYSQL_TYPE_LONGLONG: + v = mysql.ParseBinaryInt24(data) + case mysql.MYSQL_TYPE_LONGLONG: n = 8 - v = ParseBinaryInt64(data) - case MYSQL_TYPE_NEWDECIMAL: + v = mysql.ParseBinaryInt64(data) + case mysql.MYSQL_TYPE_NEWDECIMAL: prec := uint8(meta >> 8) scale := uint8(meta & 0xFF) v, n, err = decodeDecimal(data, int(prec), int(scale), e.useDecimal) - case MYSQL_TYPE_FLOAT: + case mysql.MYSQL_TYPE_FLOAT: n = 4 - v = ParseBinaryFloat32(data) - case MYSQL_TYPE_DOUBLE: + v = mysql.ParseBinaryFloat32(data) + case mysql.MYSQL_TYPE_DOUBLE: n = 8 - v = ParseBinaryFloat64(data) - case MYSQL_TYPE_BIT: + v = mysql.ParseBinaryFloat64(data) + case mysql.MYSQL_TYPE_BIT: nbits := ((meta >> 8) * 8) + (meta & 0xFF) n = int(nbits+7) / 8 // use int64 for bit v, err = decodeBit(data, int(nbits), n) - case MYSQL_TYPE_TIMESTAMP: + case mysql.MYSQL_TYPE_TIMESTAMP: n = 4 t := binary.LittleEndian.Uint32(data) if t == 0 { @@ -1281,10 +1281,10 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16, isPartial boo timestampStringLocation: e.timestampStringLocation, }) } - case MYSQL_TYPE_TIMESTAMP2: + case mysql.MYSQL_TYPE_TIMESTAMP2: v, n, err = decodeTimestamp2(data, meta, e.timestampStringLocation) v = e.parseFracTime(v) - case MYSQL_TYPE_DATETIME: + case mysql.MYSQL_TYPE_DATETIME: n = 8 i64 := binary.LittleEndian.Uint64(data) if i64 == 0 { @@ -1306,29 +1306,29 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16, isPartial boo Dec: 0, }) } - case MYSQL_TYPE_DATETIME2: + case mysql.MYSQL_TYPE_DATETIME2: v, n, err = decodeDatetime2(data, meta) v = e.parseFracTime(v) - case MYSQL_TYPE_TIME: + case mysql.MYSQL_TYPE_TIME: n = 3 - i32 := uint32(FixedLengthInt(data[0:3])) + i32 := uint32(mysql.FixedLengthInt(data[0:3])) if i32 == 0 { v = "00:00:00" } else { v = fmt.Sprintf("%02d:%02d:%02d", i32/10000, (i32%10000)/100, i32%100) } - case MYSQL_TYPE_TIME2: + case mysql.MYSQL_TYPE_TIME2: v, n, err = decodeTime2(data, meta) - case MYSQL_TYPE_DATE: + case mysql.MYSQL_TYPE_DATE: n = 3 - i32 := uint32(FixedLengthInt(data[0:3])) + i32 := uint32(mysql.FixedLengthInt(data[0:3])) if i32 == 0 { v = "0000-00-00" } else { v = fmt.Sprintf("%04d-%02d-%02d", i32/(16*32), i32/32%16, i32%32) } - case MYSQL_TYPE_YEAR: + case mysql.MYSQL_TYPE_YEAR: n = 1 year := int(data[0]) if year == 0 { @@ -1336,7 +1336,7 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16, isPartial boo } else { v = year + 1900 } - case MYSQL_TYPE_ENUM: + case mysql.MYSQL_TYPE_ENUM: l := meta & 0xFF switch l { case 1: @@ -1348,22 +1348,22 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16, isPartial boo default: err = fmt.Errorf("Unknown ENUM packlen=%d", l) } - case MYSQL_TYPE_SET: + case mysql.MYSQL_TYPE_SET: n = int(meta & 0xFF) nbits := n * 8 v, err = littleDecodeBit(data, nbits, n) - case MYSQL_TYPE_BLOB: + case mysql.MYSQL_TYPE_BLOB: v, n, err = decodeBlob(data, meta) - case MYSQL_TYPE_VARCHAR, - MYSQL_TYPE_VAR_STRING: + case mysql.MYSQL_TYPE_VARCHAR, + mysql.MYSQL_TYPE_VAR_STRING: length = int(meta) v, n = decodeString(data, length) - case MYSQL_TYPE_STRING: + case mysql.MYSQL_TYPE_STRING: v, n = decodeString(data, length) - case MYSQL_TYPE_JSON: + case mysql.MYSQL_TYPE_JSON: // Refer: https://github.com/shyiko/mysql-binlog-connector-java/blob/master/src/main/java/com/github/shyiko/mysql/binlog/event/deserialization/AbstractRowsEventDataDeserializer.java#L404 - length = int(FixedLengthInt(data[0:meta])) + length = int(mysql.FixedLengthInt(data[0:meta])) n = length + int(meta) /* @@ -1397,7 +1397,7 @@ func (e *RowsEvent) decodeValue(data []byte, tp byte, meta uint16, isPartial boo } } } - case MYSQL_TYPE_GEOMETRY: + case mysql.MYSQL_TYPE_GEOMETRY: // MySQL saves Geometry as Blob in binlog // Seem that the binary format is SRID (4 bytes) + WKB, outer can use // MySQL GeoFromWKB or others to create the geometry data. @@ -1547,15 +1547,15 @@ func decodeBit(data []byte, nbits int, length int) (value int64, err error) { case 2: value = int64(binary.BigEndian.Uint16(data)) case 3: - value = int64(BFixedLengthInt(data[0:3])) + value = int64(mysql.BFixedLengthInt(data[0:3])) case 4: value = int64(binary.BigEndian.Uint32(data)) case 5: - value = int64(BFixedLengthInt(data[0:5])) + value = int64(mysql.BFixedLengthInt(data[0:5])) case 6: - value = int64(BFixedLengthInt(data[0:6])) + value = int64(mysql.BFixedLengthInt(data[0:6])) case 7: - value = int64(BFixedLengthInt(data[0:7])) + value = int64(mysql.BFixedLengthInt(data[0:7])) case 8: value = int64(binary.BigEndian.Uint64(data)) default: @@ -1579,15 +1579,15 @@ func littleDecodeBit(data []byte, nbits int, length int) (value int64, err error case 2: value = int64(binary.LittleEndian.Uint16(data)) case 3: - value = int64(FixedLengthInt(data[0:3])) + value = int64(mysql.FixedLengthInt(data[0:3])) case 4: value = int64(binary.LittleEndian.Uint32(data)) case 5: - value = int64(FixedLengthInt(data[0:5])) + value = int64(mysql.FixedLengthInt(data[0:5])) case 6: - value = int64(FixedLengthInt(data[0:6])) + value = int64(mysql.FixedLengthInt(data[0:6])) case 7: - value = int64(FixedLengthInt(data[0:7])) + value = int64(mysql.FixedLengthInt(data[0:7])) case 8: value = int64(binary.LittleEndian.Uint64(data)) default: @@ -1614,7 +1614,7 @@ func decodeTimestamp2(data []byte, dec uint16, timestampStringLocation *time.Loc case 3, 4: usec = int64(binary.BigEndian.Uint16(data[4:])) * 100 case 5, 6: - usec = int64(BFixedLengthInt(data[4:7])) + usec = int64(mysql.BFixedLengthInt(data[4:7])) } if sec == 0 { @@ -1634,7 +1634,7 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { // get datetime binary length n := int(5 + (dec+1)/2) - intPart := int64(BFixedLengthInt(data[0:5])) - DATETIMEF_INT_OFS + intPart := int64(mysql.BFixedLengthInt(data[0:5])) - DATETIMEF_INT_OFS var frac int64 = 0 switch dec { @@ -1643,7 +1643,7 @@ func decodeDatetime2(data []byte, dec uint16) (interface{}, int, error) { case 3, 4: frac = int64(binary.BigEndian.Uint16(data[5:7])) * 100 case 5, 6: - frac = int64(BFixedLengthInt(data[5:8])) + frac = int64(mysql.BFixedLengthInt(data[5:8])) } if intPart == 0 { @@ -1702,7 +1702,7 @@ func decodeTime2(data []byte, dec uint16) (string, int, error) { frac := int64(0) switch dec { case 1, 2: - intPart = int64(BFixedLengthInt(data[0:3])) - TIMEF_INT_OFS + intPart = int64(mysql.BFixedLengthInt(data[0:3])) - TIMEF_INT_OFS frac = int64(data[3]) if intPart < 0 && frac != 0 { /* @@ -1727,7 +1727,7 @@ func decodeTime2(data []byte, dec uint16) (string, int, error) { } tmp = intPart<<24 + frac*10000 case 3, 4: - intPart = int64(BFixedLengthInt(data[0:3])) - TIMEF_INT_OFS + intPart = int64(mysql.BFixedLengthInt(data[0:3])) - TIMEF_INT_OFS frac = int64(binary.BigEndian.Uint16(data[3:5])) if intPart < 0 && frac != 0 { /* @@ -1740,10 +1740,10 @@ func decodeTime2(data []byte, dec uint16) (string, int, error) { tmp = intPart<<24 + frac*100 case 5, 6: - tmp = int64(BFixedLengthInt(data[0:6])) - TIMEF_OFS + tmp = int64(mysql.BFixedLengthInt(data[0:6])) - TIMEF_OFS return timeFormat(tmp, dec, n) default: - intPart = int64(BFixedLengthInt(data[0:3])) - TIMEF_INT_OFS + intPart = int64(mysql.BFixedLengthInt(data[0:3])) - TIMEF_INT_OFS tmp = intPart << 24 } @@ -1789,7 +1789,7 @@ func decodeBlob(data []byte, meta uint16) (v []byte, n int, err error) { v = data[2 : 2+length] n = length + 2 case 3: - length = int(FixedLengthInt(data[0:3])) + length = int(mysql.FixedLengthInt(data[0:3])) v = data[3 : 3+length] n = length + 3 case 4: diff --git a/replication/transaction_payload_event.go b/replication/transaction_payload_event.go index 58826143d..64bc0dc3c 100644 --- a/replication/transaction_payload_event.go +++ b/replication/transaction_payload_event.go @@ -6,9 +6,8 @@ import ( "fmt" "io" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/klauspost/compress/zstd" - - . "github.com/go-mysql-org/go-mysql/mysql" ) // On The Wire: Field Types @@ -72,23 +71,23 @@ func (e *TransactionPayloadEvent) decodeFields(data []byte) error { offset := uint64(0) for { - fieldType := FixedLengthInt(data[offset : offset+1]) + fieldType := mysql.FixedLengthInt(data[offset : offset+1]) offset++ if fieldType == OTW_PAYLOAD_HEADER_END_MARK { e.Payload = data[offset:] break } else { - fieldLength := FixedLengthInt(data[offset : offset+1]) + fieldLength := mysql.FixedLengthInt(data[offset : offset+1]) offset++ switch fieldType { case OTW_PAYLOAD_SIZE_FIELD: - e.Size = FixedLengthInt(data[offset : offset+fieldLength]) + e.Size = mysql.FixedLengthInt(data[offset : offset+fieldLength]) case OTW_PAYLOAD_COMPRESSION_TYPE_FIELD: - e.CompressionType = FixedLengthInt(data[offset : offset+fieldLength]) + e.CompressionType = mysql.FixedLengthInt(data[offset : offset+fieldLength]) case OTW_PAYLOAD_UNCOMPRESSED_SIZE_FIELD: - e.UncompressedSize = FixedLengthInt(data[offset : offset+fieldLength]) + e.UncompressedSize = mysql.FixedLengthInt(data[offset : offset+fieldLength]) } offset += fieldLength diff --git a/server/auth.go b/server/auth.go index e2852a385..0fcf8f1ee 100644 --- a/server/auth.go +++ b/server/auth.go @@ -9,7 +9,7 @@ import ( "crypto/tls" "fmt" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" ) @@ -20,13 +20,13 @@ var ( func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) error { switch authPluginName { - case AUTH_NATIVE_PASSWORD: + case mysql.AUTH_NATIVE_PASSWORD: if err := c.acquirePassword(); err != nil { return err } return c.compareNativePasswordAuthData(clientAuthData, c.password) - case AUTH_CACHING_SHA2_PASSWORD: + case mysql.AUTH_CACHING_SHA2_PASSWORD: if err := c.compareCacheSha2PasswordAuthData(clientAuthData); err != nil { return err } @@ -35,7 +35,7 @@ func (c *Conn) compareAuthData(authPluginName string, clientAuthData []byte) err } return nil - case AUTH_SHA256_PASSWORD: + case mysql.AUTH_SHA256_PASSWORD: if err := c.acquirePassword(); err != nil { return err } @@ -59,7 +59,7 @@ func (c *Conn) acquirePassword() error { return err } if !found { - return NewDefaultError(ER_NO_SUCH_USER, c.user, c.RemoteAddr().String()) + return mysql.NewDefaultError(mysql.ER_NO_SUCH_USER, c.user, c.RemoteAddr().String()) } c.password = password return nil @@ -94,7 +94,7 @@ func scrambleValidation(cached, nonce, scramble []byte) bool { } func (c *Conn) compareNativePasswordAuthData(clientAuthData []byte, password string) error { - if bytes.Equal(CalcPassword(c.salt, []byte(password)), clientAuthData) { + if bytes.Equal(mysql.CalcPassword(c.salt, []byte(password)), clientAuthData) { return nil } return errAccessDenied(password) @@ -160,7 +160,7 @@ func (c *Conn) compareCacheSha2PasswordAuthData(clientAuthData []byte) error { if err := c.acquirePassword(); err != nil { return err } - if bytes.Equal(CalcCachingSha2Password(c.salt, c.password), clientAuthData) { + if bytes.Equal(mysql.CalcCachingSha2Password(c.salt, c.password), clientAuthData) { // 'fast' auth: write "More data" packet (first byte == 0x01) with the second byte = 0x03 return c.writeAuthMoreDataFastAuth() } diff --git a/server/auth_switch_response.go b/server/auth_switch_response.go index 835cf7bab..b11e1437f 100644 --- a/server/auth_switch_response.go +++ b/server/auth_switch_response.go @@ -9,7 +9,7 @@ import ( "crypto/tls" "fmt" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" ) @@ -20,13 +20,13 @@ func (c *Conn) handleAuthSwitchResponse() error { } switch c.authPluginName { - case AUTH_NATIVE_PASSWORD: + case mysql.AUTH_NATIVE_PASSWORD: if err := c.acquirePassword(); err != nil { return err } return c.compareNativePasswordAuthData(authData, c.password) - case AUTH_CACHING_SHA2_PASSWORD: + case mysql.AUTH_CACHING_SHA2_PASSWORD: if !c.cachingSha2FullAuth { // Switched auth method but no MoreData packet send yet if err := c.compareCacheSha2PasswordAuthData(authData); err != nil { @@ -45,7 +45,7 @@ func (c *Conn) handleAuthSwitchResponse() error { c.writeCachingSha2Cache() return nil - case AUTH_SHA256_PASSWORD: + case mysql.AUTH_SHA256_PASSWORD: cont, err := c.handlePublicKeyRetrieval(authData) if err != nil { return err diff --git a/server/command.go b/server/command.go index adf840718..27aa6bf9b 100644 --- a/server/command.go +++ b/server/command.go @@ -6,7 +6,6 @@ import ( "log" "github.com/go-mysql-org/go-mysql/mysql" - . "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/replication" "github.com/go-mysql-org/go-mysql/utils" ) @@ -17,15 +16,15 @@ type Handler interface { UseDB(dbName string) error //handle COM_QUERY command, like SELECT, INSERT, UPDATE, etc... //If Result has a Resultset (SELECT, SHOW, etc...), we will send this as the response, otherwise, we will send Result - HandleQuery(query string) (*Result, error) + HandleQuery(query string) (*mysql.Result, error) //handle COM_FILED_LIST command - HandleFieldList(table string, fieldWildcard string) ([]*Field, error) + HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) //handle COM_STMT_PREPARE, params is the param number for this statement, columns is the column number //context will be used later for statement execute HandleStmtPrepare(query string) (params int, columns int, context interface{}, err error) //handle COM_STMT_EXECUTE, context is the previous one set in prepare //query is the statement prepare query, and args is the params for this statement - HandleStmtExecute(context interface{}, query string, args []interface{}) (*Result, error) + HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) //handle COM_STMT_CLOSE, context is the previous one set in prepare //this handler has no response HandleStmtClose(context interface{}) error @@ -38,8 +37,8 @@ type Handler interface { type ReplicationHandler interface { // handle Replication command HandleRegisterSlave(data []byte) error - HandleBinlogDump(pos Position) (*replication.BinlogStreamer, error) - HandleBinlogDumpGTID(gtidSet *MysqlGTIDSet) (*replication.BinlogStreamer, error) + HandleBinlogDump(pos mysql.Position) (*replication.BinlogStreamer, error) + HandleBinlogDumpGTID(gtidSet *mysql.MysqlGTIDSet) (*replication.BinlogStreamer, error) } // HandleCommand is handling commands received by the server @@ -76,25 +75,25 @@ func (c *Conn) dispatch(data []byte) interface{} { data = data[1:] switch cmd { - case COM_QUIT: + case mysql.COM_QUIT: c.Close() c.Conn = nil return noResponse{} - case COM_QUERY: + case mysql.COM_QUERY: if r, err := c.h.HandleQuery(utils.ByteSliceToString(data)); err != nil { return err } else { return r } - case COM_PING: + case mysql.COM_PING: return nil - case COM_INIT_DB: + case mysql.COM_INIT_DB: if err := c.h.UseDB(utils.ByteSliceToString(data)); err != nil { return err } else { return nil } - case COM_FIELD_LIST: + case mysql.COM_FIELD_LIST: index := bytes.IndexByte(data, 0x00) table := utils.ByteSliceToString(data[0:index]) wildcard := utils.ByteSliceToString(data[index+1:]) @@ -104,7 +103,7 @@ func (c *Conn) dispatch(data []byte) interface{} { } else { return fs } - case COM_STMT_PREPARE: + case mysql.COM_STMT_PREPARE: c.stmtID++ st := new(Stmt) st.ID = c.stmtID @@ -117,41 +116,41 @@ func (c *Conn) dispatch(data []byte) interface{} { c.stmts[c.stmtID] = st return st } - case COM_STMT_EXECUTE: + case mysql.COM_STMT_EXECUTE: if r, err := c.handleStmtExecute(data); err != nil { return err } else { return r } - case COM_STMT_CLOSE: + case mysql.COM_STMT_CLOSE: if err := c.handleStmtClose(data); err != nil { return err } return noResponse{} - case COM_STMT_SEND_LONG_DATA: + case mysql.COM_STMT_SEND_LONG_DATA: if err := c.handleStmtSendLongData(data); err != nil { return err } return noResponse{} - case COM_STMT_RESET: + case mysql.COM_STMT_RESET: if r, err := c.handleStmtReset(data); err != nil { return err } else { return r } - case COM_SET_OPTION: + case mysql.COM_SET_OPTION: if err := c.h.HandleOtherCommand(cmd, data); err != nil { return err } return eofResponse{} - case COM_REGISTER_SLAVE: + case mysql.COM_REGISTER_SLAVE: if h, ok := c.h.(ReplicationHandler); ok { return h.HandleRegisterSlave(data) } else { return c.h.HandleOtherCommand(cmd, data) } - case COM_BINLOG_DUMP: + case mysql.COM_BINLOG_DUMP: if h, ok := c.h.(ReplicationHandler); ok { pos, err := parseBinlogDump(data) if err != nil { @@ -165,7 +164,7 @@ func (c *Conn) dispatch(data []byte) interface{} { } else { return c.h.HandleOtherCommand(cmd, data) } - case COM_BINLOG_DUMP_GTID: + case mysql.COM_BINLOG_DUMP_GTID: if h, ok := c.h.(ReplicationHandler); ok { gtidSet, err := parseBinlogDumpGTID(data) if err != nil { @@ -200,7 +199,7 @@ func (h EmptyHandler) UseDB(dbName string) error { } // HandleQuery is called for COM_QUERY -func (h EmptyHandler) HandleQuery(query string) (*Result, error) { +func (h EmptyHandler) HandleQuery(query string) (*mysql.Result, error) { log.Printf("Received: Query: %s", query) // These two queries are implemented for minimal support for MySQL Shell @@ -223,7 +222,7 @@ func (h EmptyHandler) HandleQuery(query string) (*Result, error) { // HandleFieldList is called for COM_FIELD_LIST packets // Note that COM_FIELD_LIST has been deprecated since MySQL 5.7.11 // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_field_list.html -func (h EmptyHandler) HandleFieldList(table string, fieldWildcard string) ([]*Field, error) { +func (h EmptyHandler) HandleFieldList(table string, fieldWildcard string) ([]*mysql.Field, error) { log.Printf("Received: FieldList: table=%s, fieldWildcard:%s", table, fieldWildcard) return nil, fmt.Errorf("not supported now") } @@ -239,7 +238,7 @@ func (h EmptyHandler) HandleStmtPrepare(query string) (int, int, interface{}, er //revive:disable:unused-parameter // HandleStmtExecute is called for COM_STMT_EXECUTE -func (h EmptyHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*Result, error) { +func (h EmptyHandler) HandleStmtExecute(context interface{}, query string, args []interface{}) (*mysql.Result, error) { log.Printf("Received: StmtExecute: %s (args: %v)", query, args) return nil, fmt.Errorf("not supported now") } @@ -259,13 +258,13 @@ func (h EmptyReplicationHandler) HandleRegisterSlave(data []byte) error { } // HandleBinlogDump is called for COM_BINLOG_DUMP (non-GTID) -func (h EmptyReplicationHandler) HandleBinlogDump(pos Position) (*replication.BinlogStreamer, error) { +func (h EmptyReplicationHandler) HandleBinlogDump(pos mysql.Position) (*replication.BinlogStreamer, error) { log.Printf("Received: BinlogDump: pos=%s", pos.String()) return nil, fmt.Errorf("not supported now") } // HandleBinlogDumpGTID is called for COM_BINLOG_DUMP_GTID -func (h EmptyReplicationHandler) HandleBinlogDumpGTID(gtidSet *MysqlGTIDSet) (*replication.BinlogStreamer, error) { +func (h EmptyReplicationHandler) HandleBinlogDumpGTID(gtidSet *mysql.MysqlGTIDSet) (*replication.BinlogStreamer, error) { log.Printf("Received: BinlogDumpGTID: gtidSet=%s", gtidSet.String()) return nil, fmt.Errorf("not supported now") } @@ -273,8 +272,8 @@ func (h EmptyReplicationHandler) HandleBinlogDumpGTID(gtidSet *MysqlGTIDSet) (*r // HandleOtherCommand is called for commands not handled elsewhere func (h EmptyHandler) HandleOtherCommand(cmd byte, data []byte) error { log.Printf("Received: OtherCommand: cmd=%x, data=%x", cmd, data) - return NewError( - ER_UNKNOWN_ERROR, + return mysql.NewError( + mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("command %d is not supported now", cmd), ) } diff --git a/server/conn.go b/server/conn.go index b3987ff88..38769b71d 100644 --- a/server/conn.go +++ b/server/conn.go @@ -5,7 +5,7 @@ import ( "net" "sync/atomic" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/packet" ) @@ -57,7 +57,7 @@ func NewConn(conn net.Conn, user string, password string, h Handler) (*Conn, err h: h, connectionID: atomic.AddUint32(&baseConnID, 1), stmts: make(map[uint32]*Stmt), - salt: RandomBuf(20), + salt: mysql.RandomBuf(20), } c.closed.Store(false) @@ -85,7 +85,7 @@ func NewCustomizedConn(conn net.Conn, serverConf *Server, p CredentialProvider, h: h, connectionID: atomic.AddUint32(&baseConnID, 1), stmts: make(map[uint32]*Stmt), - salt: RandomBuf(20), + salt: mysql.RandomBuf(20), } c.closed.Store(false) @@ -104,11 +104,12 @@ func (c *Conn) handshake() error { if err := c.readHandshakeResponse(); err != nil { if errors.Is(err, ErrAccessDenied) { - var usingPasswd uint16 = ER_YES + var usingPasswd uint16 = mysql.ER_YES if errors.Is(err, ErrAccessDeniedNoPassword) { - usingPasswd = ER_NO + usingPasswd = mysql.ER_NO } - err = NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.RemoteAddr().String(), MySQLErrName[usingPasswd]) + err = mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR, c.user, + c.RemoteAddr().String(), mysql.MySQLErrName[usingPasswd]) } _ = c.writeError(err) return err @@ -167,19 +168,19 @@ func (c *Conn) ConnectionID() uint32 { } func (c *Conn) IsAutoCommit() bool { - return c.HasStatus(SERVER_STATUS_AUTOCOMMIT) + return c.HasStatus(mysql.SERVER_STATUS_AUTOCOMMIT) } func (c *Conn) IsInTransaction() bool { - return c.HasStatus(SERVER_STATUS_IN_TRANS) + return c.HasStatus(mysql.SERVER_STATUS_IN_TRANS) } func (c *Conn) SetInTransaction() { - c.SetStatus(SERVER_STATUS_IN_TRANS) + c.SetStatus(mysql.SERVER_STATUS_IN_TRANS) } func (c *Conn) ClearInTransaction() { - c.UnsetStatus(SERVER_STATUS_IN_TRANS) + c.UnsetStatus(mysql.SERVER_STATUS_IN_TRANS) } func (c *Conn) SetStatus(status uint16) { diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 762843865..ab904cf9e 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -5,7 +5,7 @@ import ( "crypto/tls" "encoding/binary" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" ) @@ -39,7 +39,7 @@ func (c *Conn) readHandshakeResponse() error { } // read connection attributes - if c.capability&CLIENT_CONNECT_ATTRS > 0 { + if c.capability&mysql.CLIENT_CONNECT_ATTRS > 0 { // readAttributes returns new position for further processing of data _, err = c.readAttributes(data, pos) if err != nil { @@ -64,18 +64,18 @@ func (c *Conn) decodeFirstPart(data []byte) (newData []byte, pos int, err error) // prevent 'panic: runtime error: index out of range' error defer func() { if recover() != nil { - err = NewDefaultError(ER_HANDSHAKE_ERROR) + err = mysql.NewDefaultError(mysql.ER_HANDSHAKE_ERROR) } }() // check CLIENT_PROTOCOL_41 - if uint32(binary.LittleEndian.Uint16(data[:2]))&CLIENT_PROTOCOL_41 == 0 { + if uint32(binary.LittleEndian.Uint16(data[:2]))&mysql.CLIENT_PROTOCOL_41 == 0 { return nil, 0, errors.New("CLIENT_PROTOCOL_41 compatible client is required") } //capability c.capability = binary.LittleEndian.Uint32(data[:4]) - if c.capability&CLIENT_SECURE_CONNECTION == 0 { + if c.capability&mysql.CLIENT_SECURE_CONNECTION == 0 { return nil, 0, errors.New("CLIENT_SECURE_CONNECTION compatible client is required") } pos += 4 @@ -92,7 +92,7 @@ func (c *Conn) decodeFirstPart(data []byte) (newData []byte, pos int, err error) // is this a SSLRequest packet? if len(data) == (4 + 4 + 1 + 23) { - if c.serverConf.capability&CLIENT_SSL == 0 { + if c.serverConf.capability&mysql.CLIENT_SSL == 0 { return nil, 0, errors.Errorf("The host '%s' does not support SSL connections", c.RemoteAddr().String()) } // switch to TLS @@ -117,7 +117,7 @@ func (c *Conn) readUserName(data []byte, pos int) (int, error) { } func (c *Conn) readDb(data []byte, pos int) (int, error) { - if c.capability&CLIENT_CONNECT_WITH_DB != 0 { + if c.capability&mysql.CLIENT_CONNECT_WITH_DB != 0 { if len(data[pos:]) == 0 { return pos, nil } @@ -133,13 +133,13 @@ func (c *Conn) readDb(data []byte, pos int) (int, error) { } func (c *Conn) readPluginName(data []byte, pos int) int { - if c.capability&CLIENT_PLUGIN_AUTH != 0 { + if c.capability&mysql.CLIENT_PLUGIN_AUTH != 0 { c.authPluginName = string(data[pos : pos+bytes.IndexByte(data[pos:], 0x00)]) pos += len(c.authPluginName) + 1 } else { // The method used is Native Authentication if both CLIENT_PROTOCOL_41 and CLIENT_SECURE_CONNECTION are set, // but CLIENT_PLUGIN_AUTH is not set, so we fallback to 'mysql_native_password' - c.authPluginName = AUTH_NATIVE_PASSWORD + c.authPluginName = mysql.AUTH_NATIVE_PASSWORD } return pos } @@ -148,23 +148,24 @@ func (c *Conn) readAuthData(data []byte, pos int) (auth []byte, authLen int, new // prevent 'panic: runtime error: index out of range' error defer func() { if recover() != nil { - err = NewDefaultError(ER_HANDSHAKE_ERROR) + err = mysql.NewDefaultError(mysql.ER_HANDSHAKE_ERROR) } }() // length encoded data - if c.capability&CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 { - authData, isNULL, readBytes, err := LengthEncodedString(data[pos:]) + if c.capability&mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA != 0 { + authData, isNULL, readBytes, err := mysql.LengthEncodedString(data[pos:]) if err != nil { return nil, 0, 0, err } if isNULL { // no auth length and no auth data, just \NUL, considered invalid auth data, and reject connection as MySQL does - return nil, 0, 0, NewDefaultError(ER_ACCESS_DENIED_ERROR, c.user, c.RemoteAddr().String(), MySQLErrName[ER_NO]) + return nil, 0, 0, mysql.NewDefaultError(mysql.ER_ACCESS_DENIED_ERROR, + c.user, c.RemoteAddr().String(), mysql.MySQLErrName[mysql.ER_NO]) } auth = authData authLen = readBytes - } else if c.capability&CLIENT_SECURE_CONNECTION != 0 { + } else if c.capability&mysql.CLIENT_SECURE_CONNECTION != 0 { // auth length and auth authLen = int(data[pos]) pos++ @@ -183,8 +184,8 @@ func (c *Conn) readAuthData(data []byte, pos int) (auth []byte, authLen int, new func (c *Conn) handlePublicKeyRetrieval(authData []byte) (bool, error) { // if the client use 'sha256_password' auth method, and request for a public key // we send back a keyfile with Protocol::AuthMoreData - if c.authPluginName == AUTH_SHA256_PASSWORD && len(authData) == 1 && authData[0] == 0x01 { - if c.serverConf.capability&CLIENT_SSL == 0 { + if c.authPluginName == mysql.AUTH_SHA256_PASSWORD && len(authData) == 1 && authData[0] == 0x01 { + if c.serverConf.capability&mysql.CLIENT_SSL == 0 { return false, errors.New("server does not support SSL: CLIENT_SSL not enabled") } if err := c.writeAuthMoreDataPubkey(); err != nil { @@ -213,7 +214,7 @@ func (c *Conn) handleAuthMatch() (bool, error) { func (c *Conn) readAttributes(data []byte, pos int) (int, error) { // read length of attribute data - attrLen, isNull, skip := LengthEncodedInt(data[pos:]) + attrLen, isNull, skip := mysql.LengthEncodedInt(data[pos:]) pos += skip if isNull { return pos, nil @@ -229,7 +230,7 @@ func (c *Conn) readAttributes(data []byte, pos int) (int, error) { // read until end of data or NUL for atrribute key/values for { - str, isNull, strLen, err := LengthEncodedString(data[pos:]) + str, isNull, strLen, err := mysql.LengthEncodedString(data[pos:]) if err != nil { return -1, err } diff --git a/server/resp.go b/server/resp.go index fd14ee54d..18eb53c43 100644 --- a/server/resp.go +++ b/server/resp.go @@ -4,25 +4,25 @@ import ( "context" "fmt" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/replication" ) -func (c *Conn) writeOK(r *Result) error { +func (c *Conn) writeOK(r *mysql.Result) error { if r == nil { - r = NewResultReserveResultset(0) + r = mysql.NewResultReserveResultset(0) } r.Status |= c.status data := make([]byte, 4, 32) - data = append(data, OK_HEADER) + data = append(data, mysql.OK_HEADER) - data = append(data, PutLengthEncodedInt(r.AffectedRows)...) - data = append(data, PutLengthEncodedInt(r.InsertId)...) + data = append(data, mysql.PutLengthEncodedInt(r.AffectedRows)...) + data = append(data, mysql.PutLengthEncodedInt(r.InsertId)...) - if c.capability&CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { data = append(data, byte(r.Status), byte(r.Status>>8)) data = append(data, byte(r.Warnings), byte(r.Warnings>>8)) } @@ -31,18 +31,18 @@ func (c *Conn) writeOK(r *Result) error { } func (c *Conn) writeError(e error) error { - var m *MyError + var m *mysql.MyError var ok bool - if m, ok = e.(*MyError); !ok { - m = NewError(ER_UNKNOWN_ERROR, e.Error()) + if m, ok = e.(*mysql.MyError); !ok { + m = mysql.NewError(mysql.ER_UNKNOWN_ERROR, e.Error()) } data := make([]byte, 4, 16+len(m.Message)) - data = append(data, ERR_HEADER) + data = append(data, mysql.ERR_HEADER) data = append(data, byte(m.Code), byte(m.Code>>8)) - if c.capability&CLIENT_PROTOCOL_41 > 0 { + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { data = append(data, '#') data = append(data, m.State...) } @@ -55,8 +55,8 @@ func (c *Conn) writeError(e error) error { func (c *Conn) writeEOF() error { data := make([]byte, 4, 9) - data = append(data, EOF_HEADER) - if c.capability&CLIENT_PROTOCOL_41 > 0 { + data = append(data, mysql.EOF_HEADER) + if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 { data = append(data, byte(c.warnings), byte(c.warnings>>8)) data = append(data, byte(c.status), byte(c.status>>8)) } @@ -67,11 +67,11 @@ func (c *Conn) writeEOF() error { // see: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_connection_phase_packets_protocol_auth_switch_request.html func (c *Conn) writeAuthSwitchRequest(newAuthPluginName string) error { data := make([]byte, 4) - data = append(data, EOF_HEADER) + data = append(data, mysql.EOF_HEADER) data = append(data, []byte(newAuthPluginName)...) data = append(data, 0x00) // new auth data - c.salt = RandomBuf(20) + c.salt = mysql.RandomBuf(20) data = append(data, c.salt...) // the online doc states it's a string.EOF, however, the actual MySQL server add a \NUL to the end, without it, the // official MySQL client will fail. @@ -94,26 +94,26 @@ func (c *Conn) readAuthSwitchRequestResponse() ([]byte, error) { func (c *Conn) writeAuthMoreDataPubkey() error { data := make([]byte, 4) - data = append(data, MORE_DATE_HEADER) + data = append(data, mysql.MORE_DATE_HEADER) data = append(data, c.serverConf.pubKey...) return c.WritePacket(data) } func (c *Conn) writeAuthMoreDataFullAuth() error { data := make([]byte, 4) - data = append(data, MORE_DATE_HEADER) - data = append(data, CACHE_SHA2_FULL_AUTH) + data = append(data, mysql.MORE_DATE_HEADER) + data = append(data, mysql.CACHE_SHA2_FULL_AUTH) return c.WritePacket(data) } func (c *Conn) writeAuthMoreDataFastAuth() error { data := make([]byte, 4) - data = append(data, MORE_DATE_HEADER) - data = append(data, CACHE_SHA2_FAST_AUTH) + data = append(data, mysql.MORE_DATE_HEADER) + data = append(data, mysql.CACHE_SHA2_FAST_AUTH) return c.WritePacket(data) } -func (c *Conn) writeResultset(r *Resultset) error { +func (c *Conn) writeResultset(r *mysql.Resultset) error { // for a streaming resultset, that handled rowdata separately in a callback // of type SelectPerRowCallback, we can suffice by ending the stream with // an EOF @@ -121,14 +121,14 @@ func (c *Conn) writeResultset(r *Resultset) error { // been taken care of already in the user-defined callback if r.StreamingDone { switch r.Streaming { - case StreamingMultiple: + case mysql.StreamingMultiple: return nil - case StreamingSelect: + case mysql.StreamingSelect: return c.writeEOF() } } - columnLen := PutLengthEncodedInt(uint64(len(r.Fields))) + columnLen := mysql.PutLengthEncodedInt(uint64(len(r.Fields))) data := make([]byte, 4, 1024) @@ -143,7 +143,7 @@ func (c *Conn) writeResultset(r *Resultset) error { // streaming select resultsets handle rowdata in a separate callback of type // SelectPerRowCallback so we're done here - if r.Streaming == StreamingSelect { + if r.Streaming == mysql.StreamingSelect { return nil } @@ -162,7 +162,7 @@ func (c *Conn) writeResultset(r *Resultset) error { return nil } -func (c *Conn) writeFieldList(fs []*Field, data []byte) error { +func (c *Conn) writeFieldList(fs []*mysql.Field, data []byte) error { if data == nil { data = make([]byte, 4, 1024) } @@ -181,18 +181,18 @@ func (c *Conn) writeFieldList(fs []*Field, data []byte) error { return nil } -func (c *Conn) writeFieldValues(fv []FieldValue) error { +func (c *Conn) writeFieldValues(fv []mysql.FieldValue) error { data := make([]byte, 4, 1024) for _, v := range fv { if v.Value() == nil { // NULL value is encoded as 0xfb here data = append(data, []byte{0xfb}...) } else { - tv, err := FormatTextValue(v.Value()) + tv, err := mysql.FormatTextValue(v.Value()) if err != nil { return err } - data = append(data, PutLengthEncodedString(tv)...) + data = append(data, mysql.PutLengthEncodedString(tv)...) } } @@ -207,7 +207,7 @@ func (c *Conn) writeBinlogEvents(s *replication.BinlogStreamer) error { return err } data := make([]byte, 4, 4+len(ev.RawData)) - data = append(data, OK_HEADER) + data = append(data, mysql.OK_HEADER) data = append(data, ev.RawData...) if err := c.WritePacket(data); err != nil { @@ -229,15 +229,15 @@ func (c *Conn) WriteValue(value interface{}) error { return c.writeError(v) case nil: return c.writeOK(nil) - case *Result: + case *mysql.Result: if v != nil && v.Resultset != nil { return c.writeResultset(v.Resultset) } else { return c.writeOK(v) } - case []*Field: + case []*mysql.Field: return c.writeFieldList(v, nil) - case []FieldValue: + case []mysql.FieldValue: return c.writeFieldValues(v) case *replication.BinlogStreamer: return c.writeBinlogEvents(v) diff --git a/server/server_conf.go b/server/server_conf.go index 4db71c6aa..87a74f370 100644 --- a/server/server_conf.go +++ b/server/server_conf.go @@ -5,7 +5,7 @@ import ( "fmt" "sync" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" ) var defaultServer = NewDefaultServer() @@ -50,11 +50,11 @@ func NewDefaultServer() *Server { return &Server{ serverVersion: "8.0.11", protocolVersion: 10, - capability: CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | - CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_SSL | - CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_ATTRS, - collationId: DEFAULT_COLLATION_ID, - defaultAuthMethod: AUTH_NATIVE_PASSWORD, + capability: mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_LONG_FLAG | mysql.CLIENT_CONNECT_WITH_DB | mysql.CLIENT_PROTOCOL_41 | + mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_SECURE_CONNECTION | mysql.CLIENT_PLUGIN_AUTH | mysql.CLIENT_SSL | + mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | mysql.CLIENT_CONNECT_ATTRS, + collationId: mysql.DEFAULT_COLLATION_ID, + defaultAuthMethod: mysql.AUTH_NATIVE_PASSWORD, pubKey: getPublicKeyFromCert(certPem), tlsConfig: tlsConf, cacheShaPassword: new(sync.Map), @@ -78,11 +78,11 @@ func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string //if !isAuthMethodAllowedByServer(defaultAuthMethod, allowedAuthMethods) { // panic(fmt.Sprintf("default auth method is not one of the allowed auth methods")) //} - var capFlag = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 | - CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS | - CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA + var capFlag = mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_LONG_FLAG | mysql.CLIENT_CONNECT_WITH_DB | mysql.CLIENT_PROTOCOL_41 | + mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_SECURE_CONNECTION | mysql.CLIENT_PLUGIN_AUTH | mysql.CLIENT_CONNECT_ATTRS | + mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA if tlsConfig != nil { - capFlag |= CLIENT_SSL + capFlag |= mysql.CLIENT_SSL } return &Server{ serverVersion: serverVersion, @@ -97,7 +97,7 @@ func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string } func isAuthMethodSupported(authMethod string) bool { - return authMethod == AUTH_NATIVE_PASSWORD || authMethod == AUTH_CACHING_SHA2_PASSWORD || authMethod == AUTH_SHA256_PASSWORD + return authMethod == mysql.AUTH_NATIVE_PASSWORD || authMethod == mysql.AUTH_CACHING_SHA2_PASSWORD || authMethod == mysql.AUTH_SHA256_PASSWORD } func (s *Server) InvalidateCache(username string, host string) { diff --git a/server/stmt.go b/server/stmt.go index be79ad3ad..56c80c507 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -6,13 +6,13 @@ import ( "math" "strconv" - . "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/mysql" "github.com/pingcap/errors" ) var ( - paramFieldData = (&Field{Name: []byte("?")}).Dump() - columnFieldData = (&Field{}).Dump() + paramFieldData = (&mysql.Field{Name: []byte("?")}).Dump() + columnFieldData = (&mysql.Field{}).Dump() ) type Stmt struct { @@ -44,11 +44,11 @@ func (c *Conn) writePrepare(s *Stmt) error { //status ok data = append(data, 0) //stmt id - data = append(data, Uint32ToBytes(s.ID)...) + data = append(data, mysql.Uint32ToBytes(s.ID)...) //number columns - data = append(data, Uint16ToBytes(uint16(s.Columns))...) + data = append(data, mysql.Uint16ToBytes(uint16(s.Columns))...) //number params - data = append(data, Uint16ToBytes(uint16(s.Params))...) + data = append(data, mysql.Uint16ToBytes(uint16(s.Params))...) //filter [00] data = append(data, 0) //warning count @@ -90,9 +90,9 @@ func (c *Conn) writePrepare(s *Stmt) error { return nil } -func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { +func (c *Conn) handleStmtExecute(data []byte) (*mysql.Result, error) { if len(data) < 9 { - return nil, ErrMalformPacket + return nil, mysql.ErrMalformPacket } pos := 0 @@ -101,7 +101,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { s, ok := c.stmts[id] if !ok { - return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5, + return nil, mysql.NewDefaultError(mysql.ER_UNKNOWN_STMT_HANDLER, 5, strconv.FormatUint(uint64(id), 10), "stmt_execute") } @@ -113,18 +113,18 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { // Make sure the first 4 bits are 0. if flag>>4 != 0 { - return nil, NewError(ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flags 0x%x", flag)) + return nil, mysql.NewError(mysql.ER_UNKNOWN_ERROR, fmt.Sprintf("unsupported flags 0x%x", flag)) } // Test for unsupported flags in the remaining 4 bits. - if flag&CURSOR_TYPE_READ_ONLY > 0 { - return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_READ_ONLY") + if flag&mysql.CURSOR_TYPE_READ_ONLY > 0 { + return nil, mysql.NewError(mysql.ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_READ_ONLY") } - if flag&CURSOR_TYPE_FOR_UPDATE > 0 { - return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_FOR_UPDATE") + if flag&mysql.CURSOR_TYPE_FOR_UPDATE > 0 { + return nil, mysql.NewError(mysql.ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_FOR_UPDATE") } - if flag&CURSOR_TYPE_SCROLLABLE > 0 { - return nil, NewError(ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_SCROLLABLE") + if flag&mysql.CURSOR_TYPE_SCROLLABLE > 0 { + return nil, mysql.NewError(mysql.ER_UNKNOWN_ERROR, "unsupported flag CURSOR_TYPE_SCROLLABLE") } //skip iteration-count, always 1 @@ -139,7 +139,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { if paramNum > 0 { nullBitmapLen := (s.Params + 7) >> 3 if len(data) < (pos + nullBitmapLen + 1) { - return nil, ErrMalformPacket + return nil, mysql.ErrMalformPacket } nullBitmaps = data[pos : pos+nullBitmapLen] pos += nullBitmapLen @@ -148,7 +148,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { if data[pos] == 1 { pos++ if len(data) < (pos + (paramNum << 1)) { - return nil, ErrMalformPacket + return nil, mysql.ErrMalformPacket } paramTypes = data[pos : pos+(paramNum<<1)] @@ -162,7 +162,7 @@ func (c *Conn) handleStmtExecute(data []byte) (*Result, error) { } } - var r *Result + var r *mysql.Result var err error if r, err = c.h.HandleStmtExecute(s.Context, s.Query, s.Args); err != nil { return nil, errors.Trace(err) @@ -193,13 +193,13 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0 switch tp { - case MYSQL_TYPE_NULL: + case mysql.MYSQL_TYPE_NULL: args[i] = nil continue - case MYSQL_TYPE_TINY: + case mysql.MYSQL_TYPE_TINY: if len(paramValues) < (pos + 1) { - return ErrMalformPacket + return mysql.ErrMalformPacket } if isUnsigned { @@ -211,9 +211,9 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) pos++ continue - case MYSQL_TYPE_SHORT, MYSQL_TYPE_YEAR: + case mysql.MYSQL_TYPE_SHORT, mysql.MYSQL_TYPE_YEAR: if len(paramValues) < (pos + 2) { - return ErrMalformPacket + return mysql.ErrMalformPacket } if isUnsigned { @@ -224,9 +224,9 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) pos += 2 continue - case MYSQL_TYPE_INT24, MYSQL_TYPE_LONG: + case mysql.MYSQL_TYPE_INT24, mysql.MYSQL_TYPE_LONG: if len(paramValues) < (pos + 4) { - return ErrMalformPacket + return mysql.ErrMalformPacket } if isUnsigned { @@ -237,9 +237,9 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) pos += 4 continue - case MYSQL_TYPE_LONGLONG: + case mysql.MYSQL_TYPE_LONGLONG: if len(paramValues) < (pos + 8) { - return ErrMalformPacket + return mysql.ErrMalformPacket } if isUnsigned { @@ -250,35 +250,35 @@ func (c *Conn) bindStmtArgs(s *Stmt, nullBitmap, paramTypes, paramValues []byte) pos += 8 continue - case MYSQL_TYPE_FLOAT: + case mysql.MYSQL_TYPE_FLOAT: if len(paramValues) < (pos + 4) { - return ErrMalformPacket + return mysql.ErrMalformPacket } args[i] = math.Float32frombits(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) pos += 4 continue - case MYSQL_TYPE_DOUBLE: + case mysql.MYSQL_TYPE_DOUBLE: if len(paramValues) < (pos + 8) { - return ErrMalformPacket + return mysql.ErrMalformPacket } args[i] = math.Float64frombits(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) pos += 8 continue - case MYSQL_TYPE_DECIMAL, MYSQL_TYPE_NEWDECIMAL, MYSQL_TYPE_VARCHAR, - MYSQL_TYPE_BIT, MYSQL_TYPE_ENUM, MYSQL_TYPE_SET, MYSQL_TYPE_TINY_BLOB, - MYSQL_TYPE_MEDIUM_BLOB, MYSQL_TYPE_LONG_BLOB, MYSQL_TYPE_BLOB, - MYSQL_TYPE_VAR_STRING, MYSQL_TYPE_STRING, MYSQL_TYPE_GEOMETRY, - MYSQL_TYPE_DATE, MYSQL_TYPE_NEWDATE, - MYSQL_TYPE_TIMESTAMP, MYSQL_TYPE_DATETIME, MYSQL_TYPE_TIME: + case mysql.MYSQL_TYPE_DECIMAL, mysql.MYSQL_TYPE_NEWDECIMAL, mysql.MYSQL_TYPE_VARCHAR, + mysql.MYSQL_TYPE_BIT, mysql.MYSQL_TYPE_ENUM, mysql.MYSQL_TYPE_SET, mysql.MYSQL_TYPE_TINY_BLOB, + mysql.MYSQL_TYPE_MEDIUM_BLOB, mysql.MYSQL_TYPE_LONG_BLOB, mysql.MYSQL_TYPE_BLOB, + mysql.MYSQL_TYPE_VAR_STRING, mysql.MYSQL_TYPE_STRING, mysql.MYSQL_TYPE_GEOMETRY, + mysql.MYSQL_TYPE_DATE, mysql.MYSQL_TYPE_NEWDATE, + mysql.MYSQL_TYPE_TIMESTAMP, mysql.MYSQL_TYPE_DATETIME, mysql.MYSQL_TYPE_TIME: if len(paramValues) < (pos + 1) { - return ErrMalformPacket + return mysql.ErrMalformPacket } - v, isNull, n, err = LengthEncodedString(paramValues[pos:]) + v, isNull, n, err = mysql.LengthEncodedString(paramValues[pos:]) pos += n if err != nil { return errors.Trace(err) @@ -330,22 +330,22 @@ func (c *Conn) handleStmtSendLongData(data []byte) error { return nil } -func (c *Conn) handleStmtReset(data []byte) (*Result, error) { +func (c *Conn) handleStmtReset(data []byte) (*mysql.Result, error) { if len(data) < 4 { - return nil, ErrMalformPacket + return nil, mysql.ErrMalformPacket } id := binary.LittleEndian.Uint32(data[0:4]) s, ok := c.stmts[id] if !ok { - return nil, NewDefaultError(ER_UNKNOWN_STMT_HANDLER, 5, + return nil, mysql.NewDefaultError(mysql.ER_UNKNOWN_STMT_HANDLER, 5, strconv.FormatUint(uint64(id), 10), "stmt_reset") } s.ResetParams() - return NewResultReserveResultset(0), nil + return mysql.NewResultReserveResultset(0), nil } // stmt close command has no response