From 7b7fbf22bab1ceb9322707c7942d03b950ad0e12 Mon Sep 17 00:00:00 2001 From: Reinier Schoof Date: Wed, 16 Jun 2021 14:30:41 +0200 Subject: [PATCH] expose capability and charset of connections to server refactored readFirstPacket for improved testability --- server/conn.go | 9 +++++++++ server/handshake_resp.go | 8 ++++++-- server/handshake_resp_test.go | 24 ++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/server/conn.go b/server/conn.go index 7f890fd7c..3fcbc7d3f 100644 --- a/server/conn.go +++ b/server/conn.go @@ -17,6 +17,7 @@ type Conn struct { serverConf *Server capability uint32 + charset uint8 authPluginName string connectionID uint32 status uint16 @@ -133,6 +134,14 @@ func (c *Conn) GetUser() string { return c.user } +func (c *Conn) Capability() uint32 { + return c.capability +} + +func (c *Conn) Charset() uint8 { + return c.charset +} + func (c *Conn) ConnectionID() uint32 { return c.connectionID } diff --git a/server/handshake_resp.go b/server/handshake_resp.go index e4f1e9be4..584226747 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -50,6 +50,10 @@ func (c *Conn) readFirstPart() ([]byte, int, error) { return nil, 0, err } + return c.decodeFirstPart(data) +} + +func (c *Conn) decodeFirstPart(data []byte) ([]byte, int, error) { pos := 0 // check CLIENT_PROTOCOL_41 @@ -67,8 +71,8 @@ func (c *Conn) readFirstPart() ([]byte, int, error) { //skip max packet size pos += 4 - //charset, skip, if you want to use another charset, use set names - //c.collation = CollationId(data[pos]) + // connection's default character set as defined + c.charset = data[pos] pos++ //skip reserved 23[00] diff --git a/server/handshake_resp_test.go b/server/handshake_resp_test.go index f1c7fdf4a..32e1426b6 100644 --- a/server/handshake_resp_test.go +++ b/server/handshake_resp_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "testing" "github.com/go-mysql-org/go-mysql/mysql" @@ -28,3 +29,26 @@ func TestReadAuthData(t *testing.T) { t.Fatalf("expected %d read bytes, got %d", len(data)-1, readBytes) } } + +func TestDecodeFirstPart(t *testing.T) { + data := []byte{141, 174, 255, 1, 0, 0, 0, 1, 8} + + c := &Conn{} + + result, pos, err := c.decodeFirstPart(data) + if err != nil { + t.Fatalf("expected nil error, got %v", err) + } + if !bytes.Equal(result, data) { + t.Fatal("expected same data, got something else") + } + if pos != 32 { + t.Fatalf("unexpected pos, got %d", pos) + } + if c.capability != 33533581 { + t.Fatalf("unexpected capability, got %d", c.capability) + } + if c.charset != 8 { + t.Fatalf("unexpected capability, got %d", c.capability) + } +}