Skip to content

Commit 8dab929

Browse files
committed
http2: make Transport retry on server's GOAWAY graceful shutdown
Debugged & wrote with Tom Bergan.

Updates golang/go#18083

Change-Id: I00a1cb748fe9c0f01c5bd4b8d1ac4438b56f1f8c
Reviewed-on: https://go-review.googlesource.com/33971
Run-TryBot: Brad Fitzpatrick <[email protected]>
TryBot-Result: Gobot Gobot <[email protected]>
Reviewed-by: Tom Bergan <[email protected]>
1 parent 0c96df3 commit 8dab929

File tree

4 files changed

+104
-12
lines changed

4 files changed

+104
-12
lines changed

http2/go18.go

+9
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@ package http2
88

99
import (
1010
"crypto/tls"
11+
"io"
1112
"net/http"
1213
)
1314

@@ -39,3 +40,11 @@ func configureServer18(h1 *http.Server, h2 *Server) error {
3940
// shouldLogPanic reports whether panicValue should be logged: it must be
// non-nil and must not be the http.ErrAbortHandler sentinel.
func shouldLogPanic(panicValue interface{}) bool {
4041
return panicValue != nil && panicValue != http.ErrAbortHandler
4142
}
43+
44+
// reqGetBody returns the request's GetBody func, which may be nil.
// This is the variant from go18.go; shouldRetryRequest calls it to obtain
// a replacement Body when replaying a request.
func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
45+
return req.GetBody
46+
}
47+
48+
// reqBodyIsNoBody reports whether body is the http.NoBody sentinel value.
func reqBodyIsNoBody(body io.ReadCloser) bool {
49+
return body == http.NoBody
50+
}

http2/not_go18.go

+10-1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,10 @@
66

77
package http2
88

9-
import "net/http"
9+
import (
10+
"io"
11+
"net/http"
12+
)
1013

1114
func configureServer18(h1 *http.Server, h2 *Server) error {
1215
// No IdleTimeout to sync prior to Go 1.8.
@@ -16,3 +19,9 @@ func configureServer18(h1 *http.Server, h2 *Server) error {
1619
// shouldLogPanic reports whether panicValue should be logged. In this
// variant (not_go18.go) every non-nil value qualifies; there is no
// http.ErrAbortHandler exemption here.
func shouldLogPanic(panicValue interface{}) bool {
1720
return panicValue != nil
1821
}
22+
23+
// reqGetBody always returns nil in this variant (not_go18.go), so callers
// see "no GetBody func available" for every request.
// NOTE(review): presumably because http.Request.GetBody does not exist
// before Go 1.8 — confirm against the file's build tags.
func reqGetBody(req *http.Request) func() (io.ReadCloser, error) {
23+
return nil
24+
}
26+
27+
// reqBodyIsNoBody always reports false in this variant (not_go18.go); the
// argument is ignored.
// NOTE(review): presumably because http.NoBody does not exist before
// Go 1.8 — confirm against the file's build tags.
func reqBodyIsNoBody(io.ReadCloser) bool { return false }

http2/transport.go

+57-6
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,7 @@ type clientStream struct {
191191
ID uint32
192192
resc chan resAndError
193193
bufPipe pipe // buffered pipe with the flow-controlled response payload
194+
startedWrite bool // started request body write; guarded by cc.mu
194195
requestedGzip bool
195196
on100 func() // optional code to run if get a 100 continue response
196197

@@ -332,8 +333,10 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
332333
}
333334
traceGotConn(req, cc)
334335
res, err := cc.RoundTrip(req)
335-
if shouldRetryRequest(req, err) {
336-
continue
336+
if err != nil {
337+
if req, err = shouldRetryRequest(req, err); err == nil {
338+
continue
339+
}
337340
}
338341
if err != nil {
339342
t.vlogf("RoundTrip failure: %v", err)
@@ -355,12 +358,41 @@ func (t *Transport) CloseIdleConnections() {
355358
// Sentinel errors compared by identity in the Transport's retry logic.
var (
356359
errClientConnClosed = errors.New("http2: client conn is closed")
357360
// errClientConnUnusable makes shouldRetryRequest retry the same request.
errClientConnUnusable = errors.New("http2: client conn not usable")
361+
362+
// errClientConnGotGoAway is delivered by setGoAway to every stream whose
// ID is greater than the GOAWAY frame's LastStreamID; shouldRetryRequest
// retries the same request for it.
errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY")
363+
// errClientConnGotGoAwayAfterSomeReqBody replaces errClientConnGotGoAway
// in ClientConn.RoundTrip when the stream's startedWrite flag is set, so
// shouldRetryRequest knows the Body may need to be recreated.
errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written")
358364
)
359365

360-
func shouldRetryRequest(req *http.Request, err error) bool {
361-
// TODO: retry GET requests (no bodies) more aggressively, if shutdown
362-
// before response.
363-
return err == errClientConnUnusable
366+
// shouldRetryRequest is called by RoundTrip when a request fails to get
367+
// response headers. It is always called with a non-nil error.
368+
// It returns either a request to retry (either the same request, or a
369+
// modified clone), or an error if the request can't be replayed.
370+
func shouldRetryRequest(req *http.Request, err error) (*http.Request, error) {
371+
// Errors are matched by identity against the package's sentinel values.
switch err {
372+
default:
373+
// Any other error is not retryable; hand it back unchanged.
return nil, err
374+
case errClientConnUnusable, errClientConnGotGoAway:
375+
// Retry with the very same request value.
return req, nil
376+
case errClientConnGotGoAwayAfterSomeReqBody:
377+
// If the Body is nil (or http.NoBody), it's safe to reuse
378+
// this request and its Body.
379+
if req.Body == nil || reqBodyIsNoBody(req.Body) {
380+
return req, nil
381+
}
382+
// Otherwise we depend on the Request having its GetBody
383+
// func defined.
384+
getBody := reqGetBody(req) // Go 1.8: getBody = req.GetBody
385+
if getBody == nil {
386+
return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error")
387+
}
388+
// Note: this err is a new variable scoped to the case clause; it
// shadows the err parameter from here on.
body, err := getBody()
389+
if err != nil {
390+
return nil, err
391+
}
392+
// Shallow-copy the request and swap in the new Body, leaving the
// caller's req value unmodified.
newReq := *req
393+
newReq.Body = body
394+
return &newReq, nil
395+
}
364396
}
365397

366398
func (t *Transport) dialClientConn(addr string, singleUse bool) (*ClientConn, error) {
@@ -513,6 +545,15 @@ func (cc *ClientConn) setGoAway(f *GoAwayFrame) {
513545
if old != nil && old.ErrCode != ErrCodeNo {
514546
cc.goAway.ErrCode = old.ErrCode
515547
}
548+
last := f.LastStreamID
549+
for streamID, cs := range cc.streams {
550+
if streamID > last {
551+
select {
552+
case cs.resc <- resAndError{err: errClientConnGotGoAway}:
553+
default:
554+
}
555+
}
556+
}
516557
}
517558

518559
func (cc *ClientConn) CanTakeNewRequest() bool {
@@ -773,6 +814,13 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (*http.Response, error) {
773814
cs.abortRequestBodyWrite(errStopReqBodyWrite)
774815
}
775816
if re.err != nil {
817+
if re.err == errClientConnGotGoAway {
818+
cc.mu.Lock()
819+
if cs.startedWrite {
820+
re.err = errClientConnGotGoAwayAfterSomeReqBody
821+
}
822+
cc.mu.Unlock()
823+
}
776824
cc.forgetStreamID(cs.ID)
777825
return nil, re.err
778826
}
@@ -2013,6 +2061,9 @@ func (t *Transport) getBodyWriterState(cs *clientStream, body io.Reader) (s body
20132061
resc := make(chan error, 1)
20142062
s.resc = resc
20152063
s.fn = func() {
2064+
cs.cc.mu.Lock()
2065+
cs.startedWrite = true
2066+
cs.cc.mu.Unlock()
20162067
resc <- cs.writeRequestBody(body, cs.req.Body)
20172068
}
20182069
s.delay = t.expectContinueTimeout()

http2/transport_test.go

+28-5
Original file line number | Diff line number | Diff line change
@@ -2747,7 +2747,6 @@ func TestTransportCancelDataResponseRace(t *testing.T) {
27472747
}
27482748

27492749
func TestTransportRetryAfterGOAWAY(t *testing.T) {
2750-
t.Skip("to be unskipped by https://go-review.googlesource.com/c/33971/")
27512750
var dialer struct {
27522751
sync.Mutex
27532752
count int
@@ -2765,6 +2764,9 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
27652764
dialer.Lock()
27662765
defer dialer.Unlock()
27672766
dialer.count++
2767+
if dialer.count == 3 {
2768+
return nil, errors.New("unexpected number of dials")
2769+
}
27682770
cc, err := net.Dial("tcp", ln.Addr().String())
27692771
if err != nil {
27702772
return nil, fmt.Errorf("dial error: %v", err)
@@ -2797,10 +2799,20 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
27972799
go func() {
27982800
req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
27992801
res, err := tr.RoundTrip(req)
2800-
t.Logf("client got %T, %v", res, err)
2802+
if res != nil {
2803+
res.Body.Close()
2804+
if got := res.Header.Get("Foo"); got != "bar" {
2805+
err = fmt.Errorf("foo header = %q; want bar", got)
2806+
}
2807+
}
2808+
if err != nil {
2809+
err = fmt.Errorf("RoundTrip: %v", err)
2810+
}
28012811
errs <- err
28022812
}()
28032813

2814+
connToClose := make(chan io.Closer, 2)
2815+
28042816
// Server for the first request.
28052817
go func() {
28062818
var ct *clientTester
@@ -2810,6 +2822,7 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
28102822
return
28112823
}
28122824

2825+
connToClose <- ct.cc
28132826
ct.greet()
28142827
hf, err := ct.firstHeaders()
28152828
if err != nil {
@@ -2821,7 +2834,6 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
28212834
errs <- fmt.Errorf("server1 failed writing GOAWAY: %v", err)
28222835
return
28232836
}
2824-
ct.cc.(*net.TCPConn).Close()
28252837
errs <- nil
28262838
}()
28272839

@@ -2834,25 +2846,27 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
28342846
return
28352847
}
28362848

2849+
connToClose <- ct.cc
28372850
ct.greet()
28382851
hf, err := ct.firstHeaders()
28392852
if err != nil {
28402853
errs <- fmt.Errorf("server2 failed reading HEADERS: %v", err)
28412854
return
28422855
}
2843-
t.Logf("server2 Got %v", hf)
2856+
t.Logf("server2 got %v", hf)
28442857

28452858
var buf bytes.Buffer
28462859
enc := hpack.NewEncoder(&buf)
28472860
enc.WriteField(hpack.HeaderField{Name: ":status", Value: "200"})
2861+
enc.WriteField(hpack.HeaderField{Name: "foo", Value: "bar"})
28482862
err = ct.fr.WriteHeaders(HeadersFrameParam{
28492863
StreamID: hf.StreamID,
28502864
EndHeaders: true,
28512865
EndStream: false,
28522866
BlockFragment: buf.Bytes(),
28532867
})
28542868
if err != nil {
2855-
errs <- fmt.Errorf("server2 failed writin responseg HEADERS: %v", err)
2869+
errs <- fmt.Errorf("server2 failed writing response HEADERS: %v", err)
28562870
} else {
28572871
errs <- nil
28582872
}
@@ -2868,4 +2882,13 @@ func TestTransportRetryAfterGOAWAY(t *testing.T) {
28682882
t.Errorf("timed out")
28692883
}
28702884
}
2885+
2886+
for {
2887+
select {
2888+
case c := <-connToClose:
2889+
c.Close()
2890+
default:
2891+
return
2892+
}
2893+
}
28712894
}

0 commit comments

Comments
 (0)