Skip to content

Commit de77a42

Browse files
committed
webtransport: close underlying h3 connection (#2862)
1 parent cc4bd1b commit de77a42

File tree

4 files changed

+146
-21
lines changed

4 files changed

+146
-21
lines changed

p2p/transport/webtransport/conn.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
tpt "github.com/libp2p/go-libp2p/core/transport"
88

99
ma "github.com/multiformats/go-multiaddr"
10+
"github.com/quic-go/quic-go"
1011
"github.com/quic-go/webtransport-go"
1112
)
1213

@@ -31,16 +32,18 @@ type conn struct {
3132
session *webtransport.Session
3233

3334
scope network.ConnManagementScope
35+
qconn quic.Connection
3436
}
3537

3638
var _ tpt.CapableConn = &conn{}
3739

38-
func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope) *conn {
40+
func newConn(tr *transport, sess *webtransport.Session, sconn *connSecurityMultiaddrs, scope network.ConnManagementScope, qconn quic.Connection) *conn {
3941
return &conn{
4042
connSecurityMultiaddrs: sconn,
4143
transport: tr,
4244
session: sess,
4345
scope: scope,
46+
qconn: qconn,
4447
}
4548
}
4649

@@ -70,7 +73,9 @@ func (c *conn) allowWindowIncrease(size uint64) bool {
7073
func (c *conn) Close() error {
7174
c.scope.Done()
7275
c.transport.removeConn(c.session)
73-
return c.session.CloseWithError(0, "")
76+
err := c.session.CloseWithError(0, "")
77+
_ = c.qconn.CloseWithError(1, "")
78+
return err
7479
}
7580

7681
func (c *conn) IsClosed() bool { return c.session.Context().Err() != nil }

p2p/transport/webtransport/listener.go

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,61 @@ import (
1515
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
1616

1717
ma "github.com/multiformats/go-multiaddr"
18+
"github.com/quic-go/quic-go"
19+
"github.com/quic-go/quic-go/http3"
1820
"github.com/quic-go/webtransport-go"
1921
)
2022

2123
const queueLen = 16
2224
const handshakeTimeout = 10 * time.Second
2325

26+
type connKey struct{}
27+
28+
// negotiatingConn is a wrapper around a quic.Connection that lets us wrap it in
29+
// our own context for the duration of the upgrade process. Upgrading a quic
30+
// connection to an h3 connection to a webtransport session.
31+
type negotiatingConn struct {
32+
quic.Connection
33+
ctx context.Context
34+
cancel context.CancelFunc
35+
// stopClose is a function that stops the connection from being closed when
36+
// the context is done. Returns true if the connection close function was
37+
// not called.
38+
stopClose func() bool
39+
err error
40+
}
41+
42+
func (c *negotiatingConn) Unwrap() (quic.Connection, error) {
43+
defer c.cancel()
44+
if c.stopClose != nil {
45+
// unwrap the first time
46+
if !c.stopClose() {
47+
c.err = errTimeout
48+
}
49+
c.stopClose = nil
50+
}
51+
if c.err != nil {
52+
return nil, c.err
53+
}
54+
return c.Connection, nil
55+
}
56+
57+
func wrapConn(ctx context.Context, c quic.Connection, handshakeTimeout time.Duration) *negotiatingConn {
58+
ctx, cancel := context.WithTimeout(ctx, handshakeTimeout)
59+
stopClose := context.AfterFunc(ctx, func() {
60+
log.Debugf("failed to handshake on conn: %s", c.RemoteAddr())
61+
c.CloseWithError(1, "")
62+
})
63+
return &negotiatingConn{
64+
Connection: c,
65+
ctx: ctx,
66+
cancel: cancel,
67+
stopClose: stopClose,
68+
}
69+
}
70+
71+
var errTimeout = errors.New("timeout")
72+
2473
type listener struct {
2574
transport *transport
2675
isStaticTLSConf bool
@@ -56,6 +105,11 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf
56105
addr: reuseListener.Addr(),
57106
multiaddr: localMultiaddr,
58107
server: webtransport.Server{
108+
H3: http3.Server{
109+
ConnContext: func(ctx context.Context, c quic.Connection) context.Context {
110+
return context.WithValue(ctx, connKey{}, c)
111+
},
112+
},
59113
CheckOrigin: func(r *http.Request) bool { return true },
60114
},
61115
}
@@ -71,7 +125,8 @@ func newListener(reuseListener quicreuse.Listener, t *transport, isStaticTLSConf
71125
log.Debugw("serving failed", "addr", ln.Addr(), "error", err)
72126
return
73127
}
74-
go ln.server.ServeQUICConn(conn)
128+
wrapped := wrapConn(ln.ctx, conn, t.handshakeTimeout)
129+
go ln.server.ServeQUICConn(wrapped)
75130
}
76131
}()
77132
return ln, nil
@@ -137,13 +192,32 @@ func (l *listener) httpHandlerWithConnScope(w http.ResponseWriter, r *http.Reque
137192
return err
138193
}
139194

140-
conn := newConn(l.transport, sess, sconn, connScope)
195+
connVal := r.Context().Value(connKey{})
196+
if connVal == nil {
197+
log.Errorf("missing conn from context")
198+
sess.CloseWithError(1, "")
199+
return errors.New("invalid context")
200+
}
201+
nconn, ok := connVal.(*negotiatingConn)
202+
if !ok {
203+
log.Errorf("unexpected connection in context. invalid conn type: %T", nconn)
204+
sess.CloseWithError(1, "")
205+
return errors.New("invalid context")
206+
}
207+
qconn, err := nconn.Unwrap()
208+
if err != nil {
209+
log.Debugf("handshake timed out: %s", r.RemoteAddr)
210+
sess.CloseWithError(1, "")
211+
return err
212+
}
213+
214+
conn := newConn(l.transport, sess, sconn, connScope, qconn)
141215
l.transport.addConn(sess, conn)
142216
select {
143217
case l.queue <- conn:
144218
default:
145219
log.Debugw("accept queue full, dropping incoming connection", "peer", sconn.RemotePeer(), "addr", r.RemoteAddr, "error", err)
146-
sess.CloseWithError(1, "")
220+
conn.Close()
147221
return errors.New("accept queue full")
148222
}
149223

p2p/transport/webtransport/transport.go

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ func WithTLSClientConfig(c *tls.Config) Option {
6060
}
6161
}
6262

63+
func WithHandshakeTimeout(d time.Duration) Option {
64+
return func(t *transport) error {
65+
t.handshakeTimeout = d
66+
return nil
67+
}
68+
}
69+
6370
type transport struct {
6471
privKey ic.PrivKey
6572
pid peer.ID
@@ -78,8 +85,9 @@ type transport struct {
7885

7986
noise *noise.Transport
8087

81-
connMx sync.Mutex
82-
conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key
88+
connMx sync.Mutex
89+
conns map[quic.ConnectionTracingID]*conn // using quic-go's ConnectionTracingKey as map key
90+
handshakeTimeout time.Duration
8391
}
8492

8593
var _ tpt.Transport = &transport{}
@@ -99,13 +107,14 @@ func New(key ic.PrivKey, psk pnet.PSK, connManager *quicreuse.ConnManager, gater
99107
return nil, err
100108
}
101109
t := &transport{
102-
pid: id,
103-
privKey: key,
104-
rcmgr: rcmgr,
105-
gater: gater,
106-
clock: clock.New(),
107-
connManager: connManager,
108-
conns: map[quic.ConnectionTracingID]*conn{},
110+
pid: id,
111+
privKey: key,
112+
rcmgr: rcmgr,
113+
gater: gater,
114+
clock: clock.New(),
115+
connManager: connManager,
116+
conns: map[quic.ConnectionTracingID]*conn{},
117+
handshakeTimeout: handshakeTimeout,
109118
}
110119
for _, opt := range opts {
111120
if err := opt(t); err != nil {
@@ -159,7 +168,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
159168
}
160169

161170
maddr, _ := ma.SplitFunc(raddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_WEBTRANSPORT })
162-
sess, err := t.dial(ctx, maddr, url, sni, certHashes)
171+
sess, qconn, err := t.dial(ctx, maddr, url, sni, certHashes)
163172
if err != nil {
164173
return nil, err
165174
}
@@ -172,12 +181,12 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
172181
sess.CloseWithError(errorCodeConnectionGating, "")
173182
return nil, fmt.Errorf("secured connection gated")
174183
}
175-
conn := newConn(t, sess, sconn, scope)
184+
conn := newConn(t, sess, sconn, scope, qconn)
176185
t.addConn(sess, conn)
177186
return conn, nil
178187
}
179188

180-
func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, error) {
189+
func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string, certHashes []multihash.DecodedMultihash) (*webtransport.Session, quic.Connection, error) {
181190
var tlsConf *tls.Config
182191
if t.tlsClientConf != nil {
183192
tlsConf = t.tlsClientConf.Clone()
@@ -200,7 +209,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
200209
}
201210
conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease)
202211
if err != nil {
203-
return nil, err
212+
return nil, nil, err
204213
}
205214
dialer := webtransport.Dialer{
206215
DialAddr: func(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
@@ -210,12 +219,14 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
210219
}
211220
rsp, sess, err := dialer.Dial(ctx, url, nil)
212221
if err != nil {
213-
return nil, err
222+
conn.CloseWithError(1, "")
223+
return nil, nil, err
214224
}
215225
if rsp.StatusCode < 200 || rsp.StatusCode > 299 {
216-
return nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
226+
conn.CloseWithError(1, "")
227+
return nil, nil, fmt.Errorf("invalid response status code: %d", rsp.StatusCode)
217228
}
218-
return sess, err
229+
return sess, conn, err
219230
}
220231

221232
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {

p2p/transport/webtransport/transport_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"io"
1313
"net"
14+
"net/http"
1415
"os"
1516
"runtime"
1617
"sync/atomic"
@@ -827,3 +828,37 @@ func TestServerRotatesCertCorrectlyAfterSteps(t *testing.T) {
827828
require.True(t, found, "Failed after hour: %v", i)
828829
}
829830
}
831+
832+
func TestH3ConnClosed(t *testing.T) {
833+
_, serverKey := newIdentity(t)
834+
tr, err := libp2pwebtransport.New(serverKey, nil, newConnManager(t), nil, nil, libp2pwebtransport.WithHandshakeTimeout(1*time.Second))
835+
require.NoError(t, err)
836+
defer tr.(io.Closer).Close()
837+
ln, err := tr.Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1/webtransport"))
838+
require.NoError(t, err)
839+
defer ln.Close()
840+
841+
p, err := net.ListenPacket("udp", "127.0.0.1:0")
842+
require.NoError(t, err)
843+
conn, err := quic.Dial(context.Background(), p, ln.Addr(), &tls.Config{
844+
InsecureSkipVerify: true,
845+
NextProtos: []string{http3.NextProtoH3},
846+
}, nil)
847+
require.NoError(t, err)
848+
rt := &http3.SingleDestinationRoundTripper{
849+
Connection: conn,
850+
}
851+
rt.Start()
852+
require.Eventually(t, func() bool {
853+
c := http.Client{
854+
Transport: rt,
855+
Timeout: 1 * time.Second,
856+
}
857+
resp, err := c.Get(fmt.Sprintf("https://%s", ln.Addr().String()))
858+
if err != nil {
859+
return true
860+
}
861+
resp.Body.Close()
862+
return false
863+
}, 10*time.Second, 1*time.Second)
864+
}

0 commit comments

Comments
 (0)