Skip to content

Commit d166653

Browse files
authored
Merge pull request #672 from skoef/serverAttributes
support client connection attributes on the server side
2 parents 051905b + 4c72b94 commit d166653

File tree

4 files changed

+102
-2
lines changed

4 files changed

+102
-2
lines changed

server/conn.go

+5
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type Conn struct {
2020
capability uint32
2121
charset uint8
2222
authPluginName string
23+
attributes map[string]string
2324
connectionID uint32
2425
status uint16
2526
warnings uint16
@@ -160,6 +161,10 @@ func (c *Conn) Charset() uint8 {
160161
return c.charset
161162
}
162163

164+
func (c *Conn) Attributes() map[string]string {
165+
return c.attributes
166+
}
167+
163168
func (c *Conn) ConnectionID() uint32 {
164169
return c.connectionID
165170
}

server/handshake_resp.go

+40-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,14 @@ func (c *Conn) readHandshakeResponse() error {
3838
return nil
3939
}
4040

41-
// ignore connect attrs for now, the proxy does not support passing attrs to actual MySQL server
41+
// read connection attributes
42+
if c.capability&CLIENT_CONNECT_ATTRS > 0 {
43+
// readAttributes returns new position for further processing of data
44+
_, err = c.readAttributes(data, pos)
45+
if err != nil {
46+
return err
47+
}
48+
}
4249

4350
// try to authenticate the client
4451
return c.compareAuthData(c.authPluginName, authData)
@@ -198,3 +205,35 @@ func (c *Conn) handleAuthMatch(authData []byte, pos int) (bool, error) {
198205
}
199206
return true, nil
200207
}
208+
209+
func (c *Conn) readAttributes(data []byte, pos int) (int, error) {
210+
attrs := make(map[string]string)
211+
i := 0
212+
var key string
213+
214+
for {
215+
str, isNull, strLen, err := LengthEncodedString(data[pos:])
216+
217+
if err != nil {
218+
return -1, err
219+
}
220+
221+
if isNull {
222+
break
223+
}
224+
225+
// reading keys or values
226+
if i%2 == 0 {
227+
key = string(str)
228+
} else {
229+
attrs[key] = string(str)
230+
}
231+
232+
pos += strLen
233+
i++
234+
}
235+
236+
c.attributes = attrs
237+
238+
return pos, nil
239+
}

server/handshake_resp_test.go

+55
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,58 @@ func TestDecodeFirstPart(t *testing.T) {
5252
t.Fatalf("unexpected capability, got %d", c.capability)
5353
}
5454
}
55+
56+
func TestReadAttributes(t *testing.T) {
57+
var err error
58+
// example data from
59+
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41
60+
data := []byte{
61+
0xb2, 0x00, 0x00, 0x01, 0x85, 0xa2, 0x1e, 0x00, 0x00, 0x00,
62+
0x00, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
63+
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
64+
0x00, 0x00, 0x72, 0x6f, 0x6f, 0x74, 0x00, 0x14, 0x22, 0x50, 0x79, 0xa2,
65+
0x12, 0xd4, 0xe8, 0x82, 0xe5, 0xb3, 0xf4, 0x1a, 0x97, 0x75, 0x6b, 0xc8,
66+
0xbe, 0xdb, 0x9f, 0x80, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x5f, 0x6e, 0x61,
67+
0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72,
68+
0x64, 0x00, 0x61, 0x03, 0x5f, 0x6f, 0x73, 0x09, 0x64, 0x65, 0x62, 0x69,
69+
0x61, 0x6e, 0x36, 0x2e, 0x30, 0x0c, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e,
70+
0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x08, 0x6c, 0x69, 0x62, 0x6d, 0x79,
71+
0x73, 0x71, 0x6c, 0x04, 0x5f, 0x70, 0x69, 0x64, 0x05, 0x32, 0x32, 0x33,
72+
0x34, 0x34, 0x0f, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x76,
73+
0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x08, 0x35, 0x2e, 0x36, 0x2e, 0x36,
74+
0x2d, 0x6d, 0x39, 0x09, 0x5f, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72,
75+
0x6d, 0x06, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x03, 0x66, 0x6f, 0x6f,
76+
0x03, 0x62, 0x61, 0x72,
77+
}
78+
pos := 85
79+
80+
c := &Conn{}
81+
82+
pos, err = c.readAttributes(data, pos)
83+
if err != nil {
84+
t.Fatalf("unexpected error: got %v", err)
85+
}
86+
87+
if pos != 182 {
88+
t.Fatalf("unexpected position: got %d", pos)
89+
}
90+
91+
if len(c.attributes) != 6 {
92+
t.Fatalf("unexpected attribute length: got %d", len(c.attributes))
93+
}
94+
95+
fixture := map[string]string{
96+
"_client_name": "libmysql",
97+
"_client_version": "5.6.6-m9",
98+
"_os": "debian6.0",
99+
"_pid": "22344",
100+
"_platform": "x86_64",
101+
"foo": "bar",
102+
}
103+
104+
for k, v := range fixture {
105+
if vv := c.attributes[k]; vv != v {
106+
t.Fatalf("unexpected value for %s, got %s instead of %s", k, vv, v)
107+
}
108+
}
109+
}

server/server_conf.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ func NewServer(serverVersion string, collationId uint8, defaultAuthMethod string
7878
// panic(fmt.Sprintf("default auth method is not one of the allowed auth methods"))
7979
//}
8080
var capFlag = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_CONNECT_WITH_DB | CLIENT_PROTOCOL_41 |
81-
CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
81+
CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS |
82+
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
8283
if tlsConfig != nil {
8384
capFlag |= CLIENT_SSL
8485
}

0 commit comments

Comments
 (0)