Skip to content

Commit 8551be2

Browse files
dvilaverdedvilaverdedveedenlance6716
authored
allow setting the collation in auth handshake (#860)
* allow setting the collation in auth handshake * Allow connect with context in order to provide configurable connect timeouts * fixing linting error * support collations IDs greater than 255 on the auth handshake * Update client/auth.go Co-authored-by: Daniël van Eeden <[email protected]> * address PR feedback * fixing comments * fixing comments * fix linting errors * restore tests that were commented out accidently * fixing more typos in the comments * Apply suggestions from code review Co-authored-by: lance6716 <[email protected]> * addressing PR feedback --------- Co-authored-by: dvilaverde <[email protected]> Co-authored-by: Daniël van Eeden <[email protected]> Co-authored-by: lance6716 <[email protected]>
1 parent 7c31dc4 commit 8551be2

File tree

4 files changed

+145
-8
lines changed

4 files changed

+145
-8
lines changed

client/auth.go

+26-4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
. "github.com/go-mysql-org/go-mysql/mysql"
1010
"github.com/go-mysql-org/go-mysql/packet"
1111
"github.com/pingcap/errors"
12+
"github.com/pingcap/tidb/pkg/parser/charset"
1213
)
1314

1415
const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
@@ -268,8 +269,25 @@ func (c *Conn) writeAuthHandshake() error {
268269
data[11] = 0x00
269270

270271
// Charset [1 byte]
271-
// use default collation id 33 here, is utf-8
272-
data[12] = DEFAULT_COLLATION_ID
272+
// use default collation id 33 here, is `utf8mb3_general_ci`
273+
collationName := c.collation
274+
if len(collationName) == 0 {
275+
collationName = DEFAULT_COLLATION_NAME
276+
}
277+
collation, err := charset.GetCollationByName(collationName)
278+
if err != nil {
279+
return fmt.Errorf("invalid collation name %s", collationName)
280+
}
281+
282+
// the MySQL protocol calls for the collation id to be sent as 1, where only the
283+
// lower 8 bits are used in this field. But wireshark shows that the first byte of
284+
// the 23 bytes of filler is used to send the right middle 8 bits of the collation id.
285+
// see https://github.com/mysql/mysql-server/pull/541
286+
data[12] = byte(collation.ID & 0xff)
287+
// if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of
288+
// padding the filler with a 0. If ID is > 255 then the first byte of filler will contain
289+
// the right middle 8 bits of the collation ID.
290+
data[13] = byte((collation.ID & 0xff00) >> 8)
273291

274292
// SSL Connection Request Packet
275293
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
@@ -291,8 +309,12 @@ func (c *Conn) writeAuthHandshake() error {
291309
}
292310

293311
// Filler [23 bytes] (all 0x00)
294-
pos := 13
295-
for ; pos < 13+23; pos++ {
312+
// the filler starts at position 13, but the first byte of the filler
313+
// has been set with the collation id earlier, so position 13 at this point
314+
// will be either 0x00, or the right middle 8 bits of the collation id.
315+
// Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00.
316+
pos := 14
317+
for ; pos < 14+22; pos++ {
296318
data[pos] = 0
297319
}
298320

client/auth_test.go

+77-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package client
22

33
import (
4+
"net"
45
"testing"
56

6-
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/pingcap/tidb/pkg/parser/charset"
78
"github.com/stretchr/testify/require"
9+
10+
"github.com/go-mysql-org/go-mysql/mysql"
11+
"github.com/go-mysql-org/go-mysql/packet"
812
)
913

1014
func TestConnGenAttributes(t *testing.T) {
@@ -34,3 +38,75 @@ func TestConnGenAttributes(t *testing.T) {
3438
require.Subset(t, data, fixt)
3539
}
3640
}
41+
42+
func TestConnCollation(t *testing.T) {
43+
collations := []string{
44+
"big5_chinese_ci",
45+
"utf8_general_ci",
46+
"utf8mb4_0900_ai_ci",
47+
"utf8mb4_de_pb_0900_ai_ci",
48+
"utf8mb4_ja_0900_as_cs",
49+
"utf8mb4_0900_bin",
50+
"utf8mb4_zh_pinyin_tidb_as_cs",
51+
}
52+
53+
// test all supported collations by calling writeAuthHandshake() and reading the bytes
54+
// sent to the server to ensure the collation id is set correctly
55+
for _, c := range collations {
56+
collation, err := charset.GetCollationByName(c)
57+
require.NoError(t, err)
58+
server := sendAuthResponse(t, collation.Name)
59+
// read the all the bytes of the handshake response so that client goroutine can complete without blocking
60+
// on the server read.
61+
handShakeResponse := make([]byte, 128)
62+
_, err = server.Read(handShakeResponse)
63+
require.NoError(t, err)
64+
65+
// validate the collation id is set correctly
66+
// if the collation ID is <= 255 the collation ID is stored in the 12th byte
67+
if collation.ID <= 255 {
68+
require.Equal(t, byte(collation.ID), handShakeResponse[12])
69+
// the 13th byte should always be 0x00
70+
require.Equal(t, byte(0x00), handShakeResponse[13])
71+
} else {
72+
// if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes
73+
require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12])
74+
require.Equal(t, byte(collation.ID>>8), handShakeResponse[13])
75+
}
76+
77+
// sanity check: validate the 22 bytes of filler with value 0x00 are set correctly
78+
for i := 14; i < 14+22; i++ {
79+
require.Equal(t, byte(0x00), handShakeResponse[i])
80+
}
81+
82+
// and finally the username
83+
username := string(handShakeResponse[36:40])
84+
require.Equal(t, "test", username)
85+
86+
require.NoError(t, server.Close())
87+
}
88+
}
89+
90+
func sendAuthResponse(t *testing.T, collation string) net.Conn {
91+
server, client := net.Pipe()
92+
c := &Conn{
93+
Conn: &packet.Conn{
94+
Conn: client,
95+
},
96+
authPluginName: "mysql_native_password",
97+
user: "test",
98+
db: "test",
99+
password: "test",
100+
proto: "tcp",
101+
collation: collation,
102+
salt: ([]byte)("123456781234567812345678"),
103+
}
104+
105+
go func() {
106+
err := c.writeAuthHandshake()
107+
require.NoError(t, err)
108+
err = c.Close()
109+
require.NoError(t, err)
110+
}()
111+
return server
112+
}

client/client_test.go

+21-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@ func TestClientSuite(t *testing.T) {
3131
func (s *clientTestSuite) SetupSuite() {
3232
var err error
3333
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
34-
s.c, err = Connect(addr, *testUser, *testPassword, "")
34+
s.c, err = Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
35+
// test the collation logic, but this is essentially a no-op since
36+
// the collation set is the default value
37+
_ = conn.SetCollation(mysql.DEFAULT_COLLATION_NAME)
38+
})
3539
require.NoError(s.T(), err)
3640

3741
var result *mysql.Result
@@ -228,6 +232,22 @@ func (s *clientTestSuite) TestConn_SetCharset() {
228232
require.NoError(s.T(), err)
229233
}
230234

235+
func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
236+
err := s.c.SetCollation("latin1_swedish_ci")
237+
require.Error(s.T(), err)
238+
require.ErrorContains(s.T(), err, "cannot set collation after connection is established")
239+
}
240+
241+
func (s *clientTestSuite) TestConn_SetCollation() {
242+
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
243+
_, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) {
244+
// test the collation logic
245+
_ = conn.SetCollation("invalid_collation")
246+
})
247+
248+
require.Error(s.T(), err)
249+
}
250+
231251
func (s *clientTestSuite) testStmt_DropTable() {
232252
str := `drop table if exists mixer_test_stmt`
233253

client/conn.go

+21-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ type Conn struct {
3737
status uint16
3838

3939
charset string
40+
// sets the collation to be set on the auth handshake, this does not issue a 'set names' command
41+
collation string
4042

4143
salt []byte
4244
authPluginName string
@@ -67,15 +69,19 @@ func Connect(addr string, user string, password string, dbName string, options .
6769
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
6870
defer cancel()
6971

70-
dialer := &net.Dialer{}
72+
return ConnectWithContext(ctx, addr, user, password, dbName, options...)
73+
}
7174

75+
// ConnectWithContext to a MySQL addr using the provided context.
76+
func ConnectWithContext(ctx context.Context, addr string, user string, password string, dbName string, options ...func(*Conn)) (*Conn, error) {
77+
dialer := &net.Dialer{}
7278
return ConnectWithDialer(ctx, "", addr, user, password, dbName, dialer.DialContext, options...)
7379
}
7480

7581
// Dialer connects to the address on the named network using the provided context.
7682
type Dialer func(ctx context.Context, network, address string) (net.Conn, error)
7783

78-
// Connect to a MySQL server using the given Dialer.
84+
// ConnectWithDialer to a MySQL server using the given Dialer.
7985
func ConnectWithDialer(ctx context.Context, network string, addr string, user string, password string, dbName string, dialer Dialer, options ...func(*Conn)) (*Conn, error) {
8086
c := new(Conn)
8187

@@ -357,6 +363,19 @@ func (c *Conn) SetCharset(charset string) error {
357363
}
358364
}
359365

366+
func (c *Conn) SetCollation(collation string) error {
367+
if len(c.serverVersion) != 0 {
368+
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
369+
}
370+
371+
c.collation = collation
372+
return nil
373+
}
374+
375+
func (c *Conn) GetCollation() string {
376+
return c.collation
377+
}
378+
360379
func (c *Conn) FieldList(table string, wildcard string) ([]*Field, error) {
361380
if err := c.writeCommandStrStr(COM_FIELD_LIST, table, wildcard); err != nil {
362381
return nil, errors.Trace(err)

0 commit comments

Comments
 (0)