Skip to content

Commit b2e6f87

Browse files
author
Chao Xu
committed
Only check connection health if the connection read loop has been idle
1 parent bc0d6c6 commit b2e6f87

File tree

3 files changed

+224
-63
lines changed

3 files changed

+224
-63
lines changed

http2/client_conn_pool.go

-53
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
package http2
88

99
import (
10-
"context"
1110
"crypto/tls"
1211
"net/http"
1312
"sync"
14-
"time"
1513
)
1614

1715
// ClientConnPool manages a pool of HTTP/2 client connections.
@@ -43,16 +41,6 @@ type clientConnPool struct {
4341
dialing map[string]*dialCall // currently in-flight dials
4442
keys map[*ClientConn][]string
4543
addConnCalls map[string]*addConnCall // in-flight addConnIfNeede calls
46-
47-
// TODO: figure out a way to allow user to configure pingPeriod and
48-
// pingTimeout.
49-
pingPeriod time.Duration // how often pings are sent on idle
50-
// connections. The connection will be closed if response is not
51-
// received within pingTimeout. 0 means no periodic pings.
52-
pingTimeout time.Duration // connection will be force closed if a Ping
53-
// response is not received within pingTimeout.
54-
pingStops map[*ClientConn]chan struct{} // channels to stop the
55-
// periodic Pings.
5644
}
5745

5846
func (p *clientConnPool) GetClientConn(req *http.Request, addr string) (*ClientConn, error) {
@@ -231,54 +219,13 @@ func (p *clientConnPool) addConnLocked(key string, cc *ClientConn) {
231219
if p.keys == nil {
232220
p.keys = make(map[*ClientConn][]string)
233221
}
234-
if p.pingStops == nil {
235-
p.pingStops = make(map[*ClientConn]chan struct{})
236-
}
237222
p.conns[key] = append(p.conns[key], cc)
238223
p.keys[cc] = append(p.keys[cc], key)
239-
if p.pingPeriod != 0 {
240-
p.pingStops[cc] = p.pingConnection(key, cc)
241-
}
242-
}
243-
244-
// TODO: ping all connections at the same tick to save tickers?
245-
func (p *clientConnPool) pingConnection(key string, cc *ClientConn) chan struct{} {
246-
done := make(chan struct{})
247-
go func() {
248-
ticker := time.NewTicker(p.pingPeriod)
249-
defer ticker.Stop()
250-
for {
251-
select {
252-
case <-done:
253-
return
254-
default:
255-
}
256-
257-
select {
258-
case <-done:
259-
return
260-
case <-ticker.C:
261-
ctx, _ := context.WithTimeout(context.Background(), p.pingTimeout)
262-
err := cc.Ping(ctx)
263-
if err != nil {
264-
cc.closeForLostPing()
265-
p.MarkDead(cc)
266-
}
267-
}
268-
}
269-
}()
270-
return done
271224
}
272225

273226
func (p *clientConnPool) MarkDead(cc *ClientConn) {
274227
p.mu.Lock()
275228
defer p.mu.Unlock()
276-
277-
if done, ok := p.pingStops[cc]; ok {
278-
close(done)
279-
delete(p.pingStops, cc)
280-
}
281-
282229
for _, key := range p.keys[cc] {
283230
vv, ok := p.conns[key]
284231
if !ok {

http2/transport.go

+92-5
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,23 @@ type Transport struct {
108108
// waiting for their turn.
109109
StrictMaxConcurrentStreams bool
110110

111+
// PingPeriod controls how often pings are sent on idle connections to
112+
// check the liveness of the connection. The connection will be closed
113+
// if a response is not received within PingTimeout.
114+
// 0 means no periodic pings. Defaults to 0.
115+
PingPeriod time.Duration
116+
// PingTimeout is the timeout after which the connection will be closed
117+
// if a response to Ping is not received.
118+
// 0 means no periodic pings. Defaults to 0.
119+
PingTimeout time.Duration
120+
// ReadIdleTimeout is the timeout after which the periodic ping for
121+
// connection health check will begin if no frame is received on the
122+
// connection.
123+
// The health check will stop once a frame is received on the
124+
// connection.
125+
// Defaults to 60s.
126+
ReadIdleTimeout time.Duration
127+
111128
// t1, if non-nil, is the standard library Transport using
112129
// this transport. Its settings are used (but not its
113130
// RoundTrip method, etc).
@@ -140,10 +157,6 @@ func ConfigureTransport(t1 *http.Transport) error {
140157

141158
func configureTransport(t1 *http.Transport) (*Transport, error) {
142159
connPool := new(clientConnPool)
143-
// TODO: figure out a way to allow user to configure pingPeriod and
144-
// pingTimeout.
145-
connPool.pingPeriod = 5 * time.Second
146-
connPool.pingTimeout = 1 * time.Second
147160
t2 := &Transport{
148161
ConnPool: noDialClientConnPool{connPool},
149162
t1: t1,
@@ -243,6 +256,8 @@ type ClientConn struct {
243256

244257
wmu sync.Mutex // held while writing; acquire AFTER mu if holding both
245258
werr error // first write error that has occurred
259+
260+
healthCheckStopCh chan struct{}
246261
}
247262

248263
// clientStream is the state for a single HTTP/2 stream. One of these
@@ -678,6 +693,49 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
678693
return cc, nil
679694
}
680695

696+
func (cc *ClientConn) healthCheck(stop chan struct{}) {
697+
pingPeriod := cc.t.PingPeriod
698+
pingTimeout := cc.t.PingTimeout
699+
if pingPeriod == 0 || pingTimeout == 0 {
700+
return
701+
}
702+
ticker := time.NewTicker(pingPeriod)
703+
defer ticker.Stop()
704+
for {
705+
select {
706+
case <-stop:
707+
return
708+
case <-ticker.C:
709+
ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
710+
err := cc.Ping(ctx)
711+
cancel()
712+
if err != nil {
713+
cc.closeForLostPing()
714+
cc.t.connPool().MarkDead(cc)
715+
return
716+
}
717+
}
718+
}
719+
}
720+
721+
func (cc *ClientConn) startHealthCheck() {
722+
if cc.healthCheckStopCh != nil {
723+
// a health check is already running
724+
return
725+
}
726+
cc.healthCheckStopCh = make(chan struct{})
727+
go cc.healthCheck(cc.healthCheckStopCh)
728+
}
729+
730+
func (cc *ClientConn) stopHealthCheck() {
731+
if cc.healthCheckStopCh == nil {
732+
// no health check running
733+
return
734+
}
735+
close(cc.healthCheckStopCh)
736+
cc.healthCheckStopCh = nil
737+
}
738+
681739
func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
682740
cc.mu.Lock()
683741
defer cc.mu.Unlock()
@@ -1717,13 +1775,42 @@ func (rl *clientConnReadLoop) cleanup() {
17171775
cc.mu.Unlock()
17181776
}
17191777

1778+
type frameAndError struct {
1779+
f Frame
1780+
err error
1781+
}
1782+
1783+
func nonBlockingReadFrame(fr *Framer) chan frameAndError {
1784+
feCh := make(chan frameAndError)
1785+
go func() {
1786+
f, err := fr.ReadFrame()
1787+
feCh <- frameAndError{f: f, err: err}
1788+
}()
1789+
return feCh
1790+
}
1791+
17201792
func (rl *clientConnReadLoop) run() error {
17211793
cc := rl.cc
17221794
rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
17231795
gotReply := false // ever saw a HEADERS reply
17241796
gotSettings := false
17251797
for {
1726-
f, err := cc.fr.ReadFrame()
1798+
var fe frameAndError
1799+
feCh := nonBlockingReadFrame(cc.fr)
1800+
to := cc.t.ReadIdleTimeout
1801+
if to == 0 {
1802+
to = 60 * time.Second
1803+
}
1804+
readIdleTimer := time.NewTimer(to)
1805+
select {
1806+
case fe = <-feCh:
1807+
cc.stopHealthCheck()
1808+
readIdleTimer.Stop()
1809+
case <-readIdleTimer.C:
1810+
cc.startHealthCheck()
1811+
fe = <-feCh
1812+
}
1813+
f, err := fe.f, fe.err
17271814
if err != nil {
17281815
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
17291816
}

http2/transport_test.go

+132-5
Original file line numberDiff line numberDiff line change
@@ -3247,11 +3247,9 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
32473247
func TestTransportCloseAfterLostPing(t *testing.T) {
32483248
clientDone := make(chan struct{})
32493249
ct := newClientTester(t)
3250-
connPool := new(clientConnPool)
3251-
connPool.pingPeriod = 1 * time.Second
3252-
connPool.pingTimeout = 100 * time.Millisecond
3253-
connPool.t = ct.tr
3254-
ct.tr.ConnPool = connPool
3250+
ct.tr.PingPeriod = 1 * time.Second
3251+
ct.tr.PingTimeout = 1 * time.Second
3252+
ct.tr.ReadIdleTimeout = 1 * time.Second
32553253
ct.client = func() error {
32563254
defer ct.cc.(*net.TCPConn).CloseWrite()
32573255
defer close(clientDone)
@@ -3270,6 +3268,135 @@ func TestTransportCloseAfterLostPing(t *testing.T) {
32703268
ct.run()
32713269
}
32723270

3271+
func TestTransportPingWhenReading(t *testing.T) {
3272+
testTransportPingWhenReading(t, 50*time.Millisecond, 100*time.Millisecond)
3273+
testTransportPingWhenReading(t, 100*time.Millisecond, 50*time.Millisecond)
3274+
}
3275+
3276+
func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration) {
3277+
var pinged bool
3278+
clientBodyBytes := []byte("hello, this is client")
3279+
clientDone := make(chan struct{})
3280+
ct := newClientTester(t)
3281+
ct.tr.PingPeriod = 10 * time.Millisecond
3282+
ct.tr.PingTimeout = 10 * time.Millisecond
3283+
ct.tr.ReadIdleTimeout = readIdleTimeout
3284+
// guards the ct.fr.Write
3285+
var wmu sync.Mutex
3286+
ct.client = func() error {
3287+
defer ct.cc.(*net.TCPConn).CloseWrite()
3288+
defer close(clientDone)
3289+
3290+
req, err := http.NewRequest("PUT", "https://dummy.tld/", bytes.NewReader(clientBodyBytes))
3291+
if err != nil {
3292+
return err
3293+
}
3294+
res, err := ct.tr.RoundTrip(req)
3295+
if err != nil {
3296+
return fmt.Errorf("RoundTrip: %v", err)
3297+
}
3298+
defer res.Body.Close()
3299+
if res.StatusCode != 200 {
3300+
return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
3301+
}
3302+
_, err = ioutil.ReadAll(res.Body)
3303+
return err
3304+
}
3305+
ct.server = func() error {
3306+
ct.greet()
3307+
var buf bytes.Buffer
3308+
enc := hpack.NewEncoder(&buf)
3309+
var dataRecv int
3310+
var closed bool
3311+
for {
3312+
f, err := ct.fr.ReadFrame()
3313+
if err != nil {
3314+
select {
3315+
case <-clientDone:
3316+
// If the client's done, it
3317+
// will have reported any
3318+
// errors on its side.
3319+
return nil
3320+
default:
3321+
return err
3322+
}
3323+
}
3324+
switch f := f.(type) {
3325+
case *WindowUpdateFrame, *SettingsFrame, *HeadersFrame:
3326+
case *DataFrame:
3327+
dataLen := len(f.Data())
3328+
if dataLen > 0 {
3329+
err := func() error {
3330+
wmu.Lock()
3331+
defer wmu.Unlock()
3332+
if dataRecv == 0 {
3333+
enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)})
3334+
ct.fr.WriteHeaders(HeadersFrameParam{
3335+
StreamID: f.StreamID,
3336+
EndHeaders: true,
3337+
EndStream: false,
3338+
BlockFragment: buf.Bytes(),
3339+
})
3340+
}
3341+
if err := ct.fr.WriteWindowUpdate(0, uint32(dataLen)); err != nil {
3342+
return err
3343+
}
3344+
if err := ct.fr.WriteWindowUpdate(f.StreamID, uint32(dataLen)); err != nil {
3345+
return err
3346+
}
3347+
return nil
3348+
}()
3349+
if err != nil {
3350+
return err
3351+
}
3352+
}
3353+
dataRecv += dataLen
3354+
3355+
if !closed && dataRecv == len(clientBodyBytes) {
3356+
closed = true
3357+
go func() {
3358+
for i := 0; i < 10; i++ {
3359+
wmu.Lock()
3360+
if err := ct.fr.WriteData(f.StreamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
3361+
wmu.Unlock()
3362+
t.Error(err)
3363+
return
3364+
}
3365+
wmu.Unlock()
3366+
time.Sleep(serverResponseInterval)
3367+
}
3368+
wmu.Lock()
3369+
if err := ct.fr.WriteData(f.StreamID, true, []byte("hello, this is last server frame")); err != nil {
3370+
wmu.Unlock()
3371+
t.Error(err)
3372+
return
3373+
}
3374+
wmu.Unlock()
3375+
}()
3376+
}
3377+
case *PingFrame:
3378+
pinged = true
3379+
if serverResponseInterval > readIdleTimeout {
3380+
wmu.Lock()
3381+
if err := ct.fr.WritePing(true, f.Data); err != nil {
3382+
wmu.Unlock()
3383+
return err
3384+
}
3385+
wmu.Unlock()
3386+
} else {
3387+
return fmt.Errorf("Unexpected ping frame: %v", f)
3388+
}
3389+
default:
3390+
return fmt.Errorf("Unexpected client frame %v", f)
3391+
}
3392+
}
3393+
}
3394+
ct.run()
3395+
if serverResponseInterval > readIdleTimeout && !pinged {
3396+
t.Errorf("expect ping")
3397+
}
3398+
}
3399+
32733400
func TestTransportRetryAfterGOAWAY(t *testing.T) {
32743401
var dialer struct {
32753402
sync.Mutex

0 commit comments

Comments
 (0)