Skip to content

Commit 3deb7dc

Browse files
dvilaverdedvilaverde
and
dvilaverde
committed
allow setting the collation in auth handshake (go-mysql-org#860)
* Allow connect with context in order to provide configurable connect timeouts * support collations IDs greater than 255 on the auth handshake --------- Co-authored-by: dvilaverde <[email protected]>
1 parent 877bc05 commit 3deb7dc

File tree

4 files changed

+97
-9
lines changed

4 files changed

+97
-9
lines changed

client/auth.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ import (
55
"crypto/tls"
66
"encoding/binary"
77
"fmt"
8-
"github.com/pingcap/tidb/pkg/parser/charset"
98

109
. "github.com/go-mysql-org/go-mysql/mysql"
1110
"github.com/go-mysql-org/go-mysql/packet"
1211
"github.com/pingcap/errors"
12+
"github.com/pingcap/tidb/pkg/parser/charset"
1313
)
1414

1515
const defaultAuthPluginName = AUTH_NATIVE_PASSWORD
@@ -269,7 +269,7 @@ func (c *Conn) writeAuthHandshake() error {
269269
data[11] = 0x00
270270

271271
// Charset [1 byte]
272-
// use default collation id 33 here, is utf-8
272+
// use default collation id 33 here, is `utf8mb3_general_ci`
273273
collationName := c.collation
274274
if len(collationName) == 0 {
275275
collationName = DEFAULT_COLLATION_NAME
@@ -279,7 +279,15 @@ func (c *Conn) writeAuthHandshake() error {
279279
return fmt.Errorf("invalid collation name %s", collationName)
280280
}
281281

282-
data[12] = byte(collation.ID)
282+
// the MySQL protocol calls for the collation id to be sent as 1, where only the
283+
// lower 8 bits are used in this field. But wireshark shows that the first byte of
284+
// the 23 bytes of filler is used to send the right middle 8 bits of the collation id.
285+
// see https://github.com/mysql/mysql-server/pull/541
286+
data[12] = byte(collation.ID & 0xff)
287+
// if the collation ID is <= 255 the middle 8 bits are 0s so this is the equivalent of
288+
// padding the filler with a 0. If ID is > 255 then the first byte of filler will contain
289+
// the right middle 8 bits of the collation ID.
290+
data[13] = byte((collation.ID & 0xff00) >> 8)
283291

284292
// SSL Connection Request Packet
285293
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
@@ -301,8 +309,12 @@ func (c *Conn) writeAuthHandshake() error {
301309
}
302310

303311
// Filler [23 bytes] (all 0x00)
304-
pos := 13
305-
for ; pos < 13+23; pos++ {
312+
// the filler starts at position 13, but the first byte of the filler
313+
// has been set with the collation id earlier, so position 13 at this point
314+
// will be either 0x00, or the right middle 8 bits of the collation id.
315+
// Therefore, we start at position 14 and fill the remaining 22 bytes with 0x00.
316+
pos := 14
317+
for ; pos < 14+22; pos++ {
306318
data[pos] = 0
307319
}
308320

client/auth_test.go

+77-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package client
22

33
import (
4+
"net"
45
"testing"
56

6-
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/pingcap/tidb/pkg/parser/charset"
78
"github.com/stretchr/testify/require"
9+
10+
"github.com/go-mysql-org/go-mysql/mysql"
11+
"github.com/go-mysql-org/go-mysql/packet"
812
)
913

1014
func TestConnGenAttributes(t *testing.T) {
@@ -34,3 +38,75 @@ func TestConnGenAttributes(t *testing.T) {
3438
require.Subset(t, data, fixt)
3539
}
3640
}
41+
42+
func TestConnCollation(t *testing.T) {
43+
collations := []string{
44+
"big5_chinese_ci",
45+
"utf8_general_ci",
46+
"utf8mb4_0900_ai_ci",
47+
"utf8mb4_de_pb_0900_ai_ci",
48+
"utf8mb4_ja_0900_as_cs",
49+
"utf8mb4_0900_bin",
50+
"utf8mb4_zh_pinyin_tidb_as_cs",
51+
}
52+
53+
// test all supported collations by calling writeAuthHandshake() and reading the bytes
54+
// sent to the server to ensure the collation id is set correctly
55+
for _, c := range collations {
56+
collation, err := charset.GetCollationByName(c)
57+
require.NoError(t, err)
58+
server := sendAuthResponse(t, collation.Name)
59+
// read the all the bytes of the handshake response so that client goroutine can complete without blocking
60+
// on the server read.
61+
handShakeResponse := make([]byte, 128)
62+
_, err = server.Read(handShakeResponse)
63+
require.NoError(t, err)
64+
65+
// validate the collation id is set correctly
66+
// if the collation ID is <= 255 the collation ID is stored in the 12th byte
67+
if collation.ID <= 255 {
68+
require.Equal(t, byte(collation.ID), handShakeResponse[12])
69+
// the 13th byte should always be 0x00
70+
require.Equal(t, byte(0x00), handShakeResponse[13])
71+
} else {
72+
// if the collation ID is > 255 the collation ID is stored in the 12th and 13th bytes
73+
require.Equal(t, byte(collation.ID&0xff), handShakeResponse[12])
74+
require.Equal(t, byte(collation.ID>>8), handShakeResponse[13])
75+
}
76+
77+
// sanity check: validate the 22 bytes of filler with value 0x00 are set correctly
78+
for i := 14; i < 14+22; i++ {
79+
require.Equal(t, byte(0x00), handShakeResponse[i])
80+
}
81+
82+
// and finally the username
83+
username := string(handShakeResponse[36:40])
84+
require.Equal(t, "test", username)
85+
86+
require.NoError(t, server.Close())
87+
}
88+
}
89+
90+
func sendAuthResponse(t *testing.T, collation string) net.Conn {
91+
server, client := net.Pipe()
92+
c := &Conn{
93+
Conn: &packet.Conn{
94+
Conn: client,
95+
},
96+
authPluginName: "mysql_native_password",
97+
user: "test",
98+
db: "test",
99+
password: "test",
100+
proto: "tcp",
101+
collation: collation,
102+
salt: ([]byte)("123456781234567812345678"),
103+
}
104+
105+
go func() {
106+
err := c.writeAuthHandshake()
107+
require.NoError(t, err)
108+
err = c.Close()
109+
require.NoError(t, err)
110+
}()
111+
return server
112+
}

client/client_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ func (s *clientTestSuite) TestConn_SetCharset() {
235235
func (s *clientTestSuite) TestConn_SetCollationAfterConnect() {
236236
err := s.c.SetCollation("latin1_swedish_ci")
237237
require.Error(s.T(), err)
238+
require.ErrorContains(s.T(), err, "cannot set collation after connection is established")
238239
}
239240

240241
func (s *clientTestSuite) TestConn_SetCollation() {

client/conn.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,11 @@ func (c *Conn) SetCharset(charset string) error {
364364
}
365365

366366
func (c *Conn) SetCollation(collation string) error {
367-
if c.status == 0 {
368-
c.collation = collation
369-
} else {
367+
if len(c.serverVersion) != 0 {
370368
return errors.Trace(errors.Errorf("cannot set collation after connection is established"))
371369
}
372370

371+
c.collation = collation
373372
return nil
374373
}
375374

0 commit comments

Comments
 (0)