diff --git a/server/handler.go b/server/handler.go index 6cc582c58..c209e73ad 100644 --- a/server/handler.go +++ b/server/handler.go @@ -9,7 +9,7 @@ import ( "time" errors "gopkg.in/src-d/go-errors.v1" - "gopkg.in/src-d/go-mysql-server.v0" + sqle "gopkg.in/src-d/go-mysql-server.v0" "gopkg.in/src-d/go-mysql-server.v0/auth" "gopkg.in/src-d/go-mysql-server.v0/sql" @@ -173,23 +173,35 @@ func (h *Handler) handleKill(conn *mysql.Conn, query string) (bool, error) { // KILL CONNECTION and KILL should close the connection. KILL QUERY only // cancels the query. // - // https://dev.mysql.com/doc/refman/5.7/en/kill.html + // https://dev.mysql.com/doc/refman/8.0/en/kill.html + // + // KILL [CONNECTION | QUERY] processlist_id + // - KILL QUERY terminates the statement the connection is currently executing, + // but leaves the connection itself intact. + // - KILL CONNECTION is the same as KILL with no modifier: + // It terminates the connection associated with the given processlist_id, + // after terminating any statement the connection is executing. if s[1] == "query" { - logrus.Infof("kill query: id %v", id) + logrus.Infof("kill query: id %d", id) h.e.Catalog.Kill(id) } else { - logrus.Infof("kill connection: id %v, pid: %v", conn.ConnectionID, id) + connID, ok := h.e.Catalog.KillConnection(id) + if !ok { + return false, errConnectionNotFound.New(connID) + } + logrus.Infof("kill connection: id %d, pid: %d", connID, id) + h.mu.Lock() - c, ok := h.c[conn.ConnectionID] - delete(h.c, conn.ConnectionID) + c, ok := h.c[connID] + if ok { + delete(h.c, connID) + } h.mu.Unlock() - if !ok { - return false, errConnectionNotFound.New(conn.ConnectionID) + return false, errConnectionNotFound.New(connID) } - h.e.Catalog.KillConnection(uint32(id)) h.sm.CloseConn(c) c.Close() } diff --git a/server/handler_test.go b/server/handler_test.go index d3ef99d5c..fb379493f 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -1,12 +1,13 @@ package server import ( + "fmt" "net" "reflect" "testing" "unsafe" - "gopkg.in/src-d/go-mysql-server.v0" + sqle "gopkg.in/src-d/go-mysql-server.v0" "gopkg.in/src-d/go-mysql-server.v0/mem" "gopkg.in/src-d/go-mysql-server.v0/sql" "gopkg.in/src-d/go-vitess.v1/mysql" @@ -167,7 +168,7 @@ func TestHandlerKill(t *testing.T) { e, NewSessionManager( func(conn *mysql.Conn, addr string) sql.Session { - return sql.NewBaseSession() + return sql.NewSession(addr, "", "", conn.ConnectionID) }, opentracing.NoopTracer{}, "foo", @@ -197,15 +198,20 @@ func TestHandlerKill(t *testing.T) { assertNoConnProcesses(t, e, conn2.ConnectionID) - err = handler.ComQuery(conn2, "KILL 1", func(res *sqltypes.Result) error { + ctx1 := handler.sm.NewContextWithQuery(conn1, "SELECT 1") + ctx1, err = handler.e.Catalog.AddProcess(ctx1, sql.QueryProcess, "SELECT 1") + require.NoError(err) + + err = handler.ComQuery(conn2, "KILL "+fmt.Sprint(ctx1.Pid()), func(res *sqltypes.Result) error { return nil }) require.NoError(err) - require.Len(handler.sm.sessions, 0) + require.Len(handler.sm.sessions, 1) require.Len(handler.c, 1) - require.Equal(conn1, handler.c[1]) - assertNoConnProcesses(t, e, conn2.ConnectionID) + _, ok := handler.c[1] + require.False(ok) + assertNoConnProcesses(t, e, conn1.ConnectionID) } func assertNoConnProcesses(t *testing.T, e *sqle.Engine, conn uint32) { diff --git a/sql/processlist.go b/sql/processlist.go index 425311a58..8da5da2f9 100644 --- a/sql/processlist.go +++ b/sql/processlist.go @@ -160,17 +160,26 @@ func (pl *ProcessList) Kill(pid uint64) { pl.Done(pid) } -// KillConnection kills all processes from the given connection. -func (pl *ProcessList) KillConnection(conn uint32) { +// KillConnection terminates the connection associated with the given processlist_id, +// after terminating any statement the connection is executing. +func (pl *ProcessList) KillConnection(pid uint64) (uint32, bool) { pl.mu.Lock() defer pl.mu.Unlock() - for pid, proc := range pl.procs { - if proc.Connection == conn { + p, ok := pl.procs[pid] + if !ok { + return 0, false + } + + connID := p.Connection + for id, proc := range pl.procs { + if proc.Connection == connID { proc.Done() - delete(pl.procs, pid) + delete(pl.procs, id) } } + + return connID, ok } // Done removes the finished process with the given pid from the process list.