Skip to content

Commit 842491d

Browse files
committed
Connection draining for h2c connections 🚽
Connections that were upgraded to HTTP/2 by use of the H2cFilter can now be drained properly. The implementation is pretty ugly because Go does not have native support for connection draining on h2c connections, and as per golang/go#26682 this isn't a priority for the project.
1 parent a049e9b commit 842491d

File tree

6 files changed

+223
-57
lines changed

6 files changed

+223
-57
lines changed

e2e_http1_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ type http1Flavour struct {
1212
T *testing.T
1313
}
1414

15-
func (f http1Flavour) Serve(svc Service) Server {
15+
func (f http1Flavour) Serve(svc Service) *Server {
1616
s, err := Listen(svc, "localhost:0")
1717
require.NoError(f.T, err)
1818
return s
1919
}
2020

21-
func (f http1Flavour) URL(s Server) string {
21+
func (f http1Flavour) URL(s *Server) string {
2222
return fmt.Sprintf("http://%s", s.Listener().Addr())
2323
}
2424

@@ -31,7 +31,7 @@ type http1TLSFlavour struct {
3131
cert tls.Certificate
3232
}
3333

34-
func (f http1TLSFlavour) Serve(svc Service) Server {
34+
func (f http1TLSFlavour) Serve(svc Service) *Server {
3535
l, err := tls.Listen("tcp", "localhost:0", &tls.Config{
3636
Certificates: []tls.Certificate{f.cert},
3737
ClientAuth: tls.NoClientCert})
@@ -41,7 +41,7 @@ func (f http1TLSFlavour) Serve(svc Service) Server {
4141
return s
4242
}
4343

44-
func (f http1TLSFlavour) URL(s Server) string {
44+
func (f http1TLSFlavour) URL(s *Server) string {
4545
return fmt.Sprintf("https://%s", s.Listener().Addr())
4646
}
4747

e2e_http2_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@ type http2H2cFlavour struct {
1313
client Service
1414
}
1515

16-
func (f http2H2cFlavour) Serve(svc Service) Server {
16+
func (f http2H2cFlavour) Serve(svc Service) *Server {
1717
svc = svc.Filter(H2cFilter)
1818
s, err := Listen(svc, "localhost:0")
1919
require.NoError(f.T, err)
2020
return s
2121
}
2222

23-
func (f http2H2cFlavour) URL(s Server) string {
23+
func (f http2H2cFlavour) URL(s *Server) string {
2424
return fmt.Sprintf("http://%s", s.Listener().Addr())
2525
}
2626

@@ -34,7 +34,7 @@ type http2H2Flavour struct {
3434
cert tls.Certificate
3535
}
3636

37-
func (f http2H2Flavour) Serve(svc Service) Server {
37+
func (f http2H2Flavour) Serve(svc Service) *Server {
3838
l, err := tls.Listen("tcp", "localhost:0", &tls.Config{
3939
Certificates: []tls.Certificate{f.cert},
4040
ClientAuth: tls.NoClientCert,
@@ -45,7 +45,7 @@ func (f http2H2Flavour) Serve(svc Service) Server {
4545
return s
4646
}
4747

48-
func (f http2H2Flavour) URL(s Server) string {
48+
func (f http2H2Flavour) URL(s *Server) string {
4949
return fmt.Sprintf("https://%s", s.Listener().Addr())
5050
}
5151

e2e_test.go

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import (
2121
)
2222

2323
type e2eFlavour interface {
24-
Serve(Service) Server
25-
URL(Server) string
24+
Serve(Service) *Server
25+
URL(*Server) string
2626
Proto() string
2727
}
2828

@@ -79,7 +79,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour)
7979
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
8080
return net.Dial(network, addr)
8181
}}
82-
defer transport.CloseIdleConnections()
8382
Client = HttpService(transport).Filter(ErrorFilter)
8483
impl(t, http2H2cFlavour{T: t})
8584
})
@@ -92,7 +91,6 @@ func someFlavours(t *testing.T, only []string, impl func(*testing.T, e2eFlavour)
9291
AllowHTTP: false,
9392
TLSClientConfig: &tls.Config{
9493
InsecureSkipVerify: true}}
95-
defer transport.CloseIdleConnections()
9694
Client = HttpService(transport).Filter(ErrorFilter)
9795
impl(t, http2H2Flavour{
9896
T: t,
@@ -115,7 +113,7 @@ func TestE2E(t *testing.T) {
115113
})
116114
svc = svc.Filter(ErrorFilter)
117115
s := flav.Serve(svc)
118-
defer s.Stop()
116+
defer s.Stop(context.Background())
119117

120118
req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{
121119
"a": "b"})
@@ -156,7 +154,7 @@ func TestE2EStreaming(t *testing.T) {
156154
})
157155
svc = svc.Filter(ErrorFilter)
158156
s := flav.Serve(svc)
159-
defer s.Stop()
157+
defer s.Stop(context.Background())
160158

161159
req := NewRequest(ctx, "GET", flav.URL(s), nil)
162160
rsp := req.Send().Response()
@@ -190,7 +188,7 @@ func TestE2EStreaming(t *testing.T) {
190188
})
191189
svc = svc.Filter(ErrorFilter)
192190
s := flav.Serve(svc)
193-
defer s.Stop()
191+
defer s.Stop(context.Background())
194192

195193
req := NewRequest(ctx, "GET", flav.URL(s), nil)
196194
reqS := Streamer()
@@ -233,7 +231,7 @@ func TestE2EDomainSocket(t *testing.T) {
233231

234232
s, err := Serve(svc, l)
235233
require.NoError(t, err)
236-
defer s.Stop()
234+
defer s.Stop(context.Background())
237235

238236
sockTransport := &httpcontrol.Transport{
239237
Dial: func(network, address string) (net.Conn, error) {
@@ -262,7 +260,7 @@ func TestE2EError(t *testing.T) {
262260
})
263261
svc = svc.Filter(ErrorFilter)
264262
s := flav.Serve(svc)
265-
defer s.Stop()
263+
defer s.Stop(context.Background())
266264

267265
req := NewRequest(ctx, "GET", flav.URL(s), nil)
268266
rsp := req.Send().Response()
@@ -290,7 +288,7 @@ func TestE2ECancellation(t *testing.T) {
290288
})
291289
svc = svc.Filter(ErrorFilter)
292290
s := flav.Serve(svc)
293-
defer s.Stop()
291+
defer s.Stop(context.Background())
294292

295293
ctx, cancel := context.WithCancel(context.Background())
296294
req := NewRequest(ctx, "GET", flav.URL(s), nil)
@@ -322,7 +320,7 @@ func TestE2ENoFollowRedirect(t *testing.T) {
322320
return rsp
323321
})
324322
s := flav.Serve(svc)
325-
defer s.Stop()
323+
defer s.Stop(context.Background())
326324

327325
ctx, cancel := context.WithCancel(context.Background())
328326
defer cancel()
@@ -354,14 +352,14 @@ func TestE2EProxiedStreamer(t *testing.T) {
354352
return rsp
355353
})
356354
s := flav.Serve(downstream)
357-
defer s.Stop()
355+
defer s.Stop(context.Background())
358356

359357
proxy := Service(func(req Request) Response {
360358
proxyReq := NewRequest(req, "GET", flav.URL(s), nil)
361359
return proxyReq.Send().Response()
362360
})
363361
ps := flav.Serve(proxy)
364-
defer ps.Stop()
362+
defer ps.Stop(context.Background())
365363

366364
req := NewRequest(ctx, "GET", flav.URL(ps), nil)
367365
rsp := req.Send().Response()
@@ -400,7 +398,7 @@ func TestE2EInfiniteContext(t *testing.T) {
400398
})
401399
svc = svc.Filter(ErrorFilter)
402400
s := flav.Serve(svc)
403-
defer s.Stop()
401+
defer s.Stop(context.Background())
404402

405403
req := NewRequest(ctx, "GET", flav.URL(s), map[string]string{
406404
"a": "b"})
@@ -435,7 +433,7 @@ func TestE2ERequestAutoChunking(t *testing.T) {
435433
})
436434
svc = svc.Filter(ErrorFilter)
437435
s := flav.Serve(svc)
438-
defer s.Stop()
436+
defer s.Stop(context.Background())
439437

440438
ctx, cancel := context.WithCancel(context.Background())
441439
defer cancel()
@@ -484,7 +482,7 @@ func TestE2EResponseAutoChunking(t *testing.T) {
484482
})
485483
svc = svc.Filter(ErrorFilter)
486484
s := flav.Serve(svc)
487-
defer s.Stop()
485+
defer s.Stop(context.Background())
488486

489487
ctx, cancel := context.WithCancel(context.Background())
490488
defer cancel()
@@ -548,7 +546,7 @@ func TestE2EStreamingCancellation(t *testing.T) {
548546
})
549547
svc = svc.Filter(ErrorFilter)
550548
s := flav.Serve(svc)
551-
defer s.Stop()
549+
defer s.Stop(context.Background())
552550

553551
ctx, cancel := context.WithCancel(context.Background())
554552
req := NewRequest(ctx, "GET", flav.URL(s), nil)
@@ -573,7 +571,7 @@ func BenchmarkRequestResponse(b *testing.B) {
573571
l, _ := net.ListenUnix("unix", addr)
574572
defer l.Close()
575573
s, _ := Serve(svc, l)
576-
defer s.Stop()
574+
defer s.Stop(context.Background())
577575

578576
sockTransport := &httpcontrol.Transport{
579577
Dial: func(network, address string) (net.Conn, error) {

h2c.go

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
package typhon
22

33
import (
4+
"bufio"
5+
"context"
6+
"net"
7+
"net/http"
48
"net/textproto"
9+
"sync"
510

11+
"github.com/monzo/terrors"
612
"golang.org/x/net/http/httpguts"
713
"golang.org/x/net/http2"
814
"golang.org/x/net/http2/h2c"
915
)
1016

11-
// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 Sections 3.2, 3.4).
17+
// H2cFilter adds HTTP/2 h2c upgrade support to the wrapped Service (as defined in RFC 7540 §3.2, §3.4).
1218
func H2cFilter(req Request, svc Service) Response {
1319
h := req.Header
1420
// h2c with prior knowledge (RFC 7540 Section 3.4)
@@ -18,9 +24,121 @@ func H2cFilter(req Request, svc Service) Response {
1824
httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Connection")], "HTTP2-Settings")
1925
if isPrior || isUpgrade {
2026
rsp := NewResponse(req)
21-
h2s := &http2.Server{}
22-
h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rsp.Writer(), &req.Request)
27+
rw, h2s, err := setupH2cHijacker(req, rsp.Writer())
28+
if err != nil {
29+
return Response{Error: err}
30+
}
31+
h2c.NewHandler(HttpHandler(svc), h2s).ServeHTTP(rw, &req.Request)
2332
return rsp
2433
}
2534
return svc(req)
2635
}
36+
37+
// Dear reader: I'm sorry, the code below isn't fun. This is because Go's h2c implementation doesn't have support for
38+
// connection draining, and all the hooks that make would make this easy are unexported.
39+
//
40+
// If this ticket gets resolved this code can be dramatically simplified, but it is not a priority for the Go team:
41+
// https://github.com/golang/go/issues/26682
42+
//
43+
// 🤢
44+
45+
var h2cConns sync.Map // map[*Server]*h2cInfo
46+
47+
// h2cInfo stores information about connections that have been upgraded by a single Typhon server
48+
type h2cInfo struct {
49+
sync.Mutex
50+
conns []*hijackedConn
51+
h2s *http2.Server
52+
}
53+
54+
// hijackedConn represents a network connection that has been hijacked for a h2c upgrade. This is necessary because we
55+
// need to know when the connection has been closed, to know if/when graceful shutdown completes.
56+
type hijackedConn struct {
57+
net.Conn
58+
closed chan struct{}
59+
closeOnce sync.Once
60+
}
61+
62+
func (c *hijackedConn) Close() error {
63+
defer c.closeOnce.Do(func() { close(c.closed) })
64+
return c.Conn.Close()
65+
}
66+
67+
type h2cHijacker struct {
68+
http.ResponseWriter
69+
http.Hijacker
70+
hijacked func(*hijackedConn)
71+
}
72+
73+
func (h h2cHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
74+
c, r, err := h.Hijacker.Hijack()
75+
conn := &hijackedConn{
76+
Conn: c,
77+
closed: make(chan struct{})}
78+
h.hijacked(conn)
79+
return conn, r, err
80+
}
81+
82+
func shutdownH2c(ctx context.Context, srv *Server) {
83+
_h2c, ok := h2cConns.Load(srv)
84+
if !ok {
85+
return
86+
}
87+
h2c := _h2c.(*h2cInfo)
88+
h2c.Lock()
89+
defer h2c.Unlock()
90+
91+
gracefulCloseLoop:
92+
for _, c := range h2c.conns {
93+
select {
94+
case <-ctx.Done():
95+
break gracefulCloseLoop
96+
case <-c.closed:
97+
h2c.conns = h2c.conns[1:]
98+
}
99+
}
100+
// If any connections remain after gracefulCloseLoop, we need to forcefully close them
101+
for _, c := range h2c.conns {
102+
c.Close()
103+
h2c.conns = h2c.conns[1:]
104+
}
105+
h2cConns.Delete(srv)
106+
}
107+
108+
func setupH2cHijacker(req Request, rw http.ResponseWriter) (http.ResponseWriter, *http2.Server, error) {
109+
hijacker, ok := rw.(http.Hijacker)
110+
if !ok {
111+
err := terrors.InternalService("hijack_impossible", "Cannot hijack response; h2c upgrade impossible", nil)
112+
return nil, nil, err
113+
}
114+
srv := req.server
115+
if srv == nil {
116+
return rw, &http2.Server{}, nil
117+
}
118+
119+
h2c := &h2cInfo{
120+
h2s: &http2.Server{}}
121+
_h2c, loaded := h2cConns.LoadOrStore(srv, h2c)
122+
h2c = _h2c.(*h2cInfo)
123+
if !loaded {
124+
// http2.ConfigureServer wires up an unexported method within the http2 library so it gracefully drains h2c
125+
// connections when the http1 server is stopped. However, this happens asynchronously: the http1 server will
126+
// think it has shut down before the h2c connections have finished draining. To work around this, we add
127+
// a shutdown function of our own in the Typhon server which waits for connections to be drained, or if things
128+
// timeout before then to terminate them forcefully.
129+
http2.ConfigureServer(srv.srv, h2c.h2s)
130+
srv.addShutdownFunc(func(ctx context.Context) {
131+
shutdownH2c(ctx, srv)
132+
})
133+
}
134+
135+
h := h2cHijacker{
136+
ResponseWriter: rw,
137+
Hijacker: hijacker,
138+
hijacked: func(c *hijackedConn) {
139+
h2c.Lock()
140+
defer h2c.Unlock()
141+
h2c.conns = append(h2c.conns, c)
142+
}}
143+
return h, h2c.h2s, nil
144+
}

request.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ type Request struct {
1717
context.Context
1818
err error // Any error from request construction; read by Client
1919
hijacker http.Hijacker
20+
server *Server
2021
}
2122

2223
// unwrappedContext returns the most "unwrapped" Context possible for that in the request.

0 commit comments

Comments
 (0)