Skip to content

http2: perform connection health check #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,19 @@ type Transport struct {
// waiting for their turn.
StrictMaxConcurrentStreams bool

// ReadIdleTimeout is the timeout after which a health check using ping
// frame will be carried out if no frame is received on the connection.
// Note that a ping response is considered a received frame, so if
// there is no other traffic on the connection, the health check will
// be performed every ReadIdleTimeout interval.
// If zero, no health check is performed.
ReadIdleTimeout time.Duration

// PingTimeout is the timeout after which the connection will be closed
// if a response to Ping is not received.
// Defaults to 15s.
PingTimeout time.Duration

// t1, if non-nil, is the standard library Transport using
// this transport. Its settings are used (but not its
// RoundTrip method, etc).
Expand All @@ -131,6 +144,14 @@ func (t *Transport) disableCompression() bool {
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
}

// pingTimeout reports how long to wait for a response to a health-check
// ping. It returns t.PingTimeout when that field is non-zero and falls
// back to 15 seconds otherwise.
func (t *Transport) pingTimeout() time.Duration {
	if d := t.PingTimeout; d != 0 {
		return d
	}
	return 15 * time.Second
}

// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
// It returns an error if t1 has already been HTTP/2-enabled.
func ConfigureTransport(t1 *http.Transport) error {
Expand Down Expand Up @@ -674,6 +695,20 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
return cc, nil
}

// healthCheck sends one ping on the connection, bounded by the
// transport's ping timeout. If the ping returns an error, the connection
// is closed via closeForLostPing and marked dead in the transport's
// connection pool.
func (cc *ClientConn) healthCheck() {
	// A single ping suffices: there is no need to ping periodically from
	// here, because the readLoop of ClientConn triggers healthCheck again
	// if no frame is received.
	ctx, cancel := context.WithTimeout(context.Background(), cc.t.pingTimeout())
	defer cancel()

	if cc.Ping(ctx) == nil {
		return
	}
	cc.closeForLostPing()
	cc.t.connPool().MarkDead(cc)
}

func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
cc.mu.Lock()
defer cc.mu.Unlock()
Expand Down Expand Up @@ -834,14 +869,12 @@ func (cc *ClientConn) sendGoAway() error {
return nil
}

// Close closes the client connection immediately.
//
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *ClientConn) Close() error {
// closeForError closes the client connection immediately. In-flight requests are interrupted.
// err is sent to streams.
func (cc *ClientConn) closeForError(err error) error {
cc.mu.Lock()
defer cc.cond.Broadcast()
defer cc.mu.Unlock()
err := errors.New("http2: client connection force closed via ClientConn.Close")
for id, cs := range cc.streams {
select {
case cs.resc <- resAndError{err: err}:
Expand All @@ -854,6 +887,20 @@ func (cc *ClientConn) Close() error {
return cc.tconn.Close()
}

// Close shuts the client connection down right away.
//
// Requests that are still in flight are interrupted. Use Shutdown instead
// when a graceful shutdown is wanted.
func (cc *ClientConn) Close() error {
	return cc.closeForError(errors.New("http2: client connection force closed via ClientConn.Close"))
}

// closeForLostPing closes the client connection immediately, passing a
// "client connection lost" error to closeForError. In-flight requests are
// interrupted.
func (cc *ClientConn) closeForLostPing() error {
	return cc.closeForError(errors.New("http2: client connection lost"))
}

const maxAllocFrameSize = 512 << 10

// frameBuffer returns a scratch buffer suitable for writing DATA frames.
Expand Down Expand Up @@ -1706,8 +1753,17 @@ func (rl *clientConnReadLoop) run() error {
rl.closeWhenIdle = cc.t.disableKeepAlives() || cc.singleUse
gotReply := false // ever saw a HEADERS reply
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
var t *time.Timer
if readIdleTimeout != 0 {
t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
defer t.Stop()
}
for {
f, err := cc.fr.ReadFrame()
if t != nil {
t.Reset(readIdleTimeout)
}
if err != nil {
cc.vlogf("http2: Transport readFrame error on conn %p: (%T) %v", cc, err, err)
}
Expand Down
160 changes: 160 additions & 0 deletions http2/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3244,6 +3244,166 @@ func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
req.Header = http.Header{}
}

// TestTransportCloseAfterLostPing checks that RoundTrip fails with a
// "client connection lost" error when ReadIdleTimeout and PingTimeout are
// both set and the server side goes silent after its greeting.
func TestTransportCloseAfterLostPing(t *testing.T) {
	ct := newClientTester(t)
	ct.tr.PingTimeout = 1 * time.Second
	ct.tr.ReadIdleTimeout = 1 * time.Second

	clientDone := make(chan struct{})
	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		defer close(clientDone)

		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		_, err := ct.tr.RoundTrip(req)
		if err != nil && strings.Contains(err.Error(), "client connection lost") {
			return nil
		}
		return fmt.Errorf("expected to get error about \"connection lost\", got %v", err)
	}
	ct.server = func() error {
		// After the greeting the server reads and writes nothing more; it
		// only waits for the client function to finish.
		ct.greet()
		<-clientDone
		return nil
	}
	ct.run()
}

// TestTransportPingWhenReading runs testTransportPingWhenReading as a set
// of parallel subtests, one per combination of ReadIdleTimeout and server
// response interval, each with the number of pings the server is expected
// to receive.
func TestTransportPingWhenReading(t *testing.T) {
	type pingCase struct {
		name                   string
		readIdleTimeout        time.Duration
		serverResponseInterval time.Duration
		expectedPingCount      int
	}
	cases := []pingCase{
		{
			name:                   "two pings in each serverResponseInterval",
			readIdleTimeout:        400 * time.Millisecond,
			serverResponseInterval: 1000 * time.Millisecond,
			expectedPingCount:      4,
		},
		{
			name:                   "one ping in each serverResponseInterval",
			readIdleTimeout:        700 * time.Millisecond,
			serverResponseInterval: 1000 * time.Millisecond,
			expectedPingCount:      2,
		},
		{
			name:                   "zero ping in each serverResponseInterval",
			readIdleTimeout:        1000 * time.Millisecond,
			serverResponseInterval: 500 * time.Millisecond,
			expectedPingCount:      0,
		},
		{
			name:                   "0 readIdleTimeout means no ping",
			readIdleTimeout:        0 * time.Millisecond,
			serverResponseInterval: 500 * time.Millisecond,
			expectedPingCount:      0,
		},
	}

	for _, c := range cases {
		c := c // give each parallel subtest closure its own copy
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()
			testTransportPingWhenReading(t, c.readIdleTimeout, c.serverResponseInterval, c.expectedPingCount)
		})
	}
}

// testTransportPingWhenReading drives a single GET through a client tester
// whose transport uses the given readIdleTimeout and a 10ms PingTimeout.
//
// The fake server answers the request HEADERS with a 200 status, then
// writes two DATA frames, sleeping serverResponseInterval after each, and
// finally one more DATA frame with the end-stream flag set. Every PING
// frame the server receives is counted and acknowledged. After the
// exchange finishes, the count must equal expectedPingCount.
func testTransportPingWhenReading(t *testing.T, readIdleTimeout, serverResponseInterval time.Duration, expectedPingCount int) {
	var pingCount int
	clientDone := make(chan struct{})
	ct := newClientTester(t)
	ct.tr.PingTimeout = 10 * time.Millisecond
	ct.tr.ReadIdleTimeout = readIdleTimeout

	// wmu guards writes on ct.fr, which are issued both from the server's
	// frame loop and from the goroutine that writes DATA frames.
	var wmu sync.Mutex

	// writeData writes one DATA frame while holding wmu.
	writeData := func(streamID uint32, endStream bool, data []byte) error {
		wmu.Lock()
		defer wmu.Unlock()
		return ct.fr.WriteData(streamID, endStream, data)
	}

	ct.client = func() error {
		defer ct.cc.(*net.TCPConn).CloseWrite()
		defer close(clientDone)
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		res, err := ct.tr.RoundTrip(req)
		if err != nil {
			return fmt.Errorf("RoundTrip: %v", err)
		}
		defer res.Body.Close()
		if res.StatusCode != 200 {
			return fmt.Errorf("status code = %v; want %v", res.StatusCode, 200)
		}
		// Drain the body so the client keeps reading until the server
		// ends the stream.
		_, err = ioutil.ReadAll(res.Body)
		return err
	}

	ct.server = func() error {
		ct.greet()
		var buf bytes.Buffer
		enc := hpack.NewEncoder(&buf)
		for {
			f, err := ct.fr.ReadFrame()
			if err != nil {
				select {
				case <-clientDone:
					// If the client's done, it
					// will have reported any
					// errors on its side.
					return nil
				default:
					return err
				}
			}
			switch f := f.(type) {
			case *WindowUpdateFrame, *SettingsFrame:
			case *HeadersFrame:
				if !f.HeadersEnded() {
					return fmt.Errorf("headers should have END_HEADERS be ended: %v", f)
				}
				if werr := enc.WriteField(hpack.HeaderField{Name: ":status", Value: strconv.Itoa(200)}); werr != nil {
					return werr
				}
				wmu.Lock()
				werr := ct.fr.WriteHeaders(HeadersFrameParam{
					StreamID:      f.StreamID,
					EndHeaders:    true,
					EndStream:     false,
					BlockFragment: buf.Bytes(),
				})
				wmu.Unlock()
				if werr != nil {
					return werr
				}

				streamID := f.StreamID
				go func() {
					for i := 0; i < 2; i++ {
						if err := writeData(streamID, false, []byte(fmt.Sprintf("hello, this is server data frame %d", i))); err != nil {
							t.Error(err)
							return
						}
						time.Sleep(serverResponseInterval)
					}
					if err := writeData(streamID, true, []byte("hello, this is last server data frame")); err != nil {
						t.Error(err)
					}
				}()
			case *PingFrame:
				pingCount++
				wmu.Lock()
				werr := ct.fr.WritePing(true, f.Data)
				wmu.Unlock()
				if werr != nil {
					return werr
				}
			default:
				return fmt.Errorf("Unexpected client frame %v", f)
			}
		}
	}
	ct.run()
	if e, a := expectedPingCount, pingCount; e != a {
		t.Errorf("expected receiving %d pings, got %d pings", e, a)
	}
}

func TestTransportRetryAfterGOAWAY(t *testing.T) {
var dialer struct {
sync.Mutex
Expand Down