Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 6ee998b

Browse files
authored
Also check sockets bind to tcp6 and fail on all closed sockets (#824)
Also check sockets bind to tcp6 and fail on all closed sockets
2 parents 2e82b0a + 604c376 commit 6ee998b

File tree

6 files changed

+66
-33
lines changed

6 files changed

+66
-33
lines changed

internal/sockstate/netstat_darwin.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import (
1212
// elements that satisfy the accept function
1313
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
1414
// (juanjux) TODO: not implemented
15-
logrus.Info("Connection checking not implemented for Darwin")
16-
return []sockTabEntry{}, nil
15+
logrus.Warn("Connection checking not implemented for Darwin")
16+
return nil, ErrSocketCheckNotImplemented.New()
1717
}
1818

1919
func GetConnInode(c *net.TCPConn) (n uint64, err error) {

internal/sockstate/netstat_linux.go

+40-14
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import (
2020
)
2121

2222
const (
23-
pathTCPTab = "/proc/net/tcp"
23+
pathTCP4Tab = "/proc/net/tcp"
24+
pathTCP6Tab = "/proc/net/tcp6"
2425
ipv4StrLen = 8
26+
ipv6StrLen = 32
2527
)
2628

2729
type procFd struct {
@@ -120,6 +122,23 @@ func parseIPv4(s string) (net.IP, error) {
120122
return ip, nil
121123
}
122124

125+
func parseIPv6(s string) (net.IP, error) {
126+
ip := make(net.IP, net.IPv6len)
127+
const grpLen = 4
128+
i, j := 0, 4
129+
for len(s) != 0 {
130+
grp := s[0:8]
131+
u, err := strconv.ParseUint(grp, 16, 32)
132+
binary.LittleEndian.PutUint32(ip[i:j], uint32(u))
133+
if err != nil {
134+
return nil, err
135+
}
136+
i, j = i+grpLen, j+grpLen
137+
s = s[8:]
138+
}
139+
return ip, nil
140+
}
141+
123142
func parseAddr(s string) (*sockAddr, error) {
124143
fields := strings.Split(s, ":")
125144
if len(fields) < 2 {
@@ -130,6 +149,8 @@ func parseAddr(s string) (*sockAddr, error) {
130149
switch len(fields[0]) {
131150
case ipv4StrLen:
132151
ip, err = parseIPv4(fields[0])
152+
case ipv6StrLen:
153+
ip, err = parseIPv6(fields[0])
133154
default:
134155
log.Fatal("Badly formatted connection address:", s)
135156
}
@@ -192,21 +213,26 @@ func parseSocktab(r io.Reader, accept AcceptFn) ([]sockTabEntry, error) {
192213
// tcpSocks returns a slice of active TCP sockets containing only those
193214
// elements that satisfy the accept function
194215
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
195-
f, err := os.Open(pathTCPTab)
196-
defer func() {
197-
_ = f.Close()
198-
}()
199-
if err != nil {
200-
return nil, err
201-
}
216+
paths := [2]string{pathTCP4Tab, pathTCP6Tab}
217+
var allTabs []sockTabEntry
218+
for _, p := range paths {
219+
f, err := os.Open(p)
220+
defer func() {
221+
_ = f.Close()
222+
}()
223+
if err != nil {
224+
return nil, err
225+
}
202226

203-
tabs, err := parseSocktab(f, accept)
204-
if err != nil {
205-
return nil, err
206-
}
227+
t, err := parseSocktab(f, accept)
228+
if err != nil {
229+
return nil, err
230+
}
231+
allTabs = append(allTabs, t...)
207232

208-
extractProcInfo(tabs)
209-
return tabs, nil
233+
}
234+
extractProcInfo(allTabs)
235+
return allTabs, nil
210236
}
211237

212238
// GetConnInode returns the Linux inode number of a TCP connection

internal/sockstate/netstat_windows.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import (
1212
// elements that satisfy the accept function
1313
func tcpSocks(accept AcceptFn) ([]sockTabEntry, error) {
1414
// (juanjux) TODO: not implemented
15-
logrus.Info("Connection checking not implemented for Windows")
16-
return []sockTabEntry{}, nil
15+
logrus.Warn("Connection checking not implemented for Windows")
16+
return nil, ErrSocketCheckNotImplemented.New()
1717
}
1818

1919
func GetConnInode(c *net.TCPConn) (n uint64, err error) {

internal/sockstate/sockstate.go

+16-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ import (
88
type SockState uint8
99

1010
const (
11-
Finished = iota
12-
Broken
11+
Broken = iota
1312
Other
1413
Error
1514
)
@@ -37,12 +36,24 @@ func GetInodeSockState(port int, inode uint64) (SockState, error) {
3736

3837
switch len(socks) {
3938
case 0:
40-
return Finished, nil
39+
return Broken, nil
4140
case 1:
42-
if socks[0].State == CloseWait {
41+
switch socks[0].State {
42+
case CloseWait:
43+
fallthrough
44+
case TimeWait:
45+
fallthrough
46+
case FinWait1:
47+
fallthrough
48+
case FinWait2:
49+
fallthrough
50+
case Close:
51+
fallthrough
52+
case Closing:
4353
return Broken, nil
54+
default:
55+
return Other, nil
4456
}
45-
return Other, nil
4657
default: // more than one sock for inode, impossible?
4758
return Error, ErrMultipleSocketsForInode.New()
4859
}

server/handler.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -211,16 +211,12 @@ func (h *Handler) ComQuery(
211211
for {
212212
select {
213213
case <-quit:
214-
// timeout or other errors detected by the calling routine
215214
return
216215
default:
217216
}
218217

219218
st, err := sockstate.GetInodeSockState(t.Port, inode)
220219
switch st {
221-
case sockstate.Finished:
222-
// Not Linux OSs will also exit here
223-
return
224220
case sockstate.Broken:
225221
errChan <- ErrConnectionWasClosed.New()
226222
return
@@ -243,6 +239,7 @@ rowLoop:
243239

244240
if r.RowsAffected == rowsBatch {
245241
if err := callback(r); err != nil {
242+
close(quit)
246243
return err
247244
}
248245

@@ -276,13 +273,12 @@ rowLoop:
276273
}
277274
timer.Reset(waitTime)
278275
}
276+
close(quit)
279277

280278
if err := rows.Close(); err != nil {
281279
return err
282280
}
283281

284-
close(quit)
285-
286282
// Even if r.RowsAffected = 0, the callback must be
287283
// called to update the state in the go-vitess' listener
288284
// and avoid returning errors when the query doesn't

server/handler_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ func TestHandlerKill(t *testing.T) {
165165
require.Len(handler.c, 2)
166166
require.Equal(conntainer1, handler.c[1])
167167
require.Equal(conntainer2, handler.c[2])
168-
169168
assertNoConnProcesses(t, e, conn2.ConnectionID)
170169

171170
ctx1 := handler.sm.NewContextWithQuery(conn1, "SELECT 1")
@@ -256,6 +255,7 @@ func TestHandlerTimeout(t *testing.T) {
256255
})
257256
require.NoError(err)
258257
}
258+
259259
func TestOkClosedConnection(t *testing.T) {
260260
require := require.New(t)
261261
e := setupMemDB(require)
@@ -282,11 +282,11 @@ func TestOkClosedConnection(t *testing.T) {
282282
0,
283283
)
284284
h.AddNetConnection(&conn)
285-
c2 := newConn(2)
286-
h.NewConnection(c2)
285+
c := newConn(1)
286+
h.NewConnection(c)
287287

288288
q := fmt.Sprintf("SELECT SLEEP(%d)", tcpCheckerSleepTime*4)
289-
err = h.ComQuery(c2, q, func(res *sqltypes.Result) error {
289+
err = h.ComQuery(c, q, func(res *sqltypes.Result) error {
290290
return nil
291291
})
292292
require.NoError(err)

0 commit comments

Comments
 (0)