From 248265da5a22884ae06c3a359a0c1f722f8a72d0 Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Thu, 10 Feb 2022 09:55:19 +0100 Subject: [PATCH] support client connection attributes on the client side --- client/auth.go | 26 ++++++++++++++++++++++++++ client/auth_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ client/conn.go | 6 ++++++ 3 files changed, 74 insertions(+) create mode 100644 client/auth_test.go diff --git a/client/auth.go b/client/auth.go index 6f0ba5f89..8410edc40 100644 --- a/client/auth.go +++ b/client/auth.go @@ -139,6 +139,20 @@ func (c *Conn) genAuthResponse(authData []byte) ([]byte, bool, error) { } } +// generate connection attributes data +func (c *Conn) genAttributes() []byte { + if len(c.attributes) == 0 { + return nil + } + + attrData := make([]byte, 0) + for k, v := range c.attributes { + attrData = append(attrData, PutLengthEncodedString([]byte(k))...) + attrData = append(attrData, PutLengthEncodedString([]byte(v))...) + } + return append(PutLengthEncodedInt(uint64(len(attrData))), attrData...) +} + // See: http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse func (c *Conn) writeAuthHandshake() error { if !authPluginAllowed(c.authPluginName) { @@ -195,6 +209,12 @@ func (c *Conn) writeAuthHandshake() error { capability |= CLIENT_CONNECT_WITH_DB length += len(c.db) + 1 } + // connection attributes + attrData := c.genAttributes() + if len(attrData) > 0 { + capability |= CLIENT_CONNECT_ATTRS + length += len(attrData) + } data := make([]byte, length+4) @@ -264,6 +284,12 @@ func (c *Conn) writeAuthHandshake() error { // Assume native client during response pos += copy(data[pos:], c.authPluginName) data[pos] = 0x00 + pos++ + + // connection attributes + if len(attrData) > 0 { + copy(data[pos:], attrData) + } return c.WritePacket(data) } diff --git a/client/auth_test.go b/client/auth_test.go new file mode 100644 index 000000000..d5c79a1b5 --- /dev/null +++ b/client/auth_test.go @@ -0,0 +1,42 @@ +package client + +import ( + "bytes" + "testing" + + "github.com/go-mysql-org/go-mysql/mysql" +) + +func TestConnGenAttributes(t *testing.T) { + c := &Conn{ + // example data from + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse41 + attributes: map[string]string{ + "_os": "debian6.0", + "_client_name": "libmysql", + "_pid": "22344", + "_client_version": "5.6.6-m9", + "_platform": "x86_64", + "foo": "bar", + }, + } + + data := c.genAttributes() + + // the order of the attributes map cannot be guaranteed so to test the content + // of the attribute data we need to check its partial contents + + if len(data) != 98 { + t.Fatalf("unexpected data length, got %d", len(data)) + } + if data[0] != 0x61 { + t.Fatalf("unexpected length-encoded int, got %#x", data[0]) + } + + for k, v := range c.attributes { + fixt := append(mysql.PutLengthEncodedString([]byte(k)), mysql.PutLengthEncodedString([]byte(v))...) + if !bytes.Contains(data, fixt) { + t.Fatalf("%s attribute not found", k) + } + } +} diff --git a/client/conn.go b/client/conn.go index c5dd5ad2a..683cc66c9 100644 --- a/client/conn.go +++ b/client/conn.go @@ -28,6 +28,8 @@ type Conn struct { // client-set capabilities only ccaps uint32 + attributes map[string]string + status uint16 charset string @@ -302,6 +304,10 @@ func (c *Conn) Rollback() error { return errors.Trace(err) } +func (c *Conn) SetAttributes(attributes map[string]string) { + c.attributes = attributes +} + func (c *Conn) SetCharset(charset string) error { if c.charset == charset { return nil