Skip to content

Commit 0858ed2

Browse files
authored
add test for tls connCheck #3025 (#3047)
* add a check for TLS connections.
1 parent 8a0c59b commit 0858ed2

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

internal/pool/conn_check.go

+5
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package pool
44

55
import (
6+
"crypto/tls"
67
"errors"
78
"io"
89
"net"
@@ -16,6 +17,10 @@ func connCheck(conn net.Conn) error {
1617
// Reset previous timeout.
1718
_ = conn.SetDeadline(time.Time{})
1819

20+
// Check if tls.Conn.
21+
if c, ok := conn.(*tls.Conn); ok {
22+
conn = c.NetConn()
23+
}
1924
sysConn, ok := conn.(syscall.Conn)
2025
if !ok {
2126
return nil

internal/pool/conn_check_test.go

+18
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
package pool
44

55
import (
6+
"crypto/tls"
67
"net"
78
"net/http/httptest"
89
"time"
@@ -14,12 +15,17 @@ import (
1415
var _ = Describe("tests conn_check with real conns", func() {
1516
var ts *httptest.Server
1617
var conn net.Conn
18+
var tlsConn *tls.Conn
1719
var err error
1820

1921
BeforeEach(func() {
2022
ts = httptest.NewServer(nil)
2123
conn, err = net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second)
2224
Expect(err).NotTo(HaveOccurred())
25+
tlsTestServer := httptest.NewUnstartedServer(nil)
26+
tlsTestServer.StartTLS()
27+
tlsConn, err = tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, tlsTestServer.Listener.Addr().Network(), tlsTestServer.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true})
28+
Expect(err).NotTo(HaveOccurred())
2329
})
2430

2531
AfterEach(func() {
@@ -33,11 +39,23 @@ var _ = Describe("tests conn_check with real conns", func() {
3339
Expect(connCheck(conn)).To(HaveOccurred())
3440
})
3541

42+
It("good tls conn check", func() {
43+
Expect(connCheck(tlsConn)).NotTo(HaveOccurred())
44+
45+
Expect(tlsConn.Close()).NotTo(HaveOccurred())
46+
Expect(connCheck(tlsConn)).To(HaveOccurred())
47+
})
48+
3649
It("bad conn check", func() {
3750
Expect(conn.Close()).NotTo(HaveOccurred())
3851
Expect(connCheck(conn)).To(HaveOccurred())
3952
})
4053

54+
It("bad tls conn check", func() {
55+
Expect(tlsConn.Close()).NotTo(HaveOccurred())
56+
Expect(connCheck(tlsConn)).To(HaveOccurred())
57+
})
58+
4159
It("check conn deadline", func() {
4260
Expect(conn.SetDeadline(time.Now())).NotTo(HaveOccurred())
4361
time.Sleep(time.Millisecond * 10)

0 commit comments

Comments
 (0)