Skip to content

Commit 62947d9

Browse files
simplify generation of stateless reset tokens (#4858)
1 parent 9950b4c commit 62947d9

16 files changed

+224
-233
lines changed

client.go

+28-16
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ type client struct {
2222
tlsConf *tls.Config
2323
config *Config
2424

25-
connIDGenerator ConnectionIDGenerator
26-
srcConnID protocol.ConnectionID
27-
destConnID protocol.ConnectionID
25+
connIDGenerator ConnectionIDGenerator
26+
statelessResetter *statelessResetter
27+
srcConnID protocol.ConnectionID
28+
destConnID protocol.ConnectionID
2829

2930
initialPacketNumber protocol.PacketNumber
3031
hasNegotiatedVersion bool
@@ -137,13 +138,14 @@ func dial(
137138
ctx context.Context,
138139
conn sendConn,
139140
connIDGenerator ConnectionIDGenerator,
141+
statelessResetter *statelessResetter,
140142
packetHandlers packetHandlerManager,
141143
tlsConf *tls.Config,
142144
config *Config,
143145
onClose func(),
144146
use0RTT bool,
145147
) (quicConn, error) {
146-
c, err := newClient(conn, connIDGenerator, config, tlsConf, onClose, use0RTT)
148+
c, err := newClient(conn, connIDGenerator, statelessResetter, config, tlsConf, onClose, use0RTT)
147149
if err != nil {
148150
return nil, err
149151
}
@@ -162,7 +164,15 @@ func dial(
162164
return c.conn, nil
163165
}
164166

165-
func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config *Config, tlsConf *tls.Config, onClose func(), use0RTT bool) (*client, error) {
167+
func newClient(
168+
sendConn sendConn,
169+
connIDGenerator ConnectionIDGenerator,
170+
statelessResetter *statelessResetter,
171+
config *Config,
172+
tlsConf *tls.Config,
173+
onClose func(),
174+
use0RTT bool,
175+
) (*client, error) {
166176
srcConnID, err := connIDGenerator.GenerateConnectionID()
167177
if err != nil {
168178
return nil, err
@@ -172,17 +182,18 @@ func newClient(sendConn sendConn, connIDGenerator ConnectionIDGenerator, config
172182
return nil, err
173183
}
174184
c := &client{
175-
connIDGenerator: connIDGenerator,
176-
srcConnID: srcConnID,
177-
destConnID: destConnID,
178-
sendConn: sendConn,
179-
use0RTT: use0RTT,
180-
onClose: onClose,
181-
tlsConf: tlsConf,
182-
config: config,
183-
version: config.Versions[0],
184-
handshakeChan: make(chan struct{}),
185-
logger: utils.DefaultLogger.WithPrefix("client"),
185+
connIDGenerator: connIDGenerator,
186+
statelessResetter: statelessResetter,
187+
srcConnID: srcConnID,
188+
destConnID: destConnID,
189+
sendConn: sendConn,
190+
use0RTT: use0RTT,
191+
onClose: onClose,
192+
tlsConf: tlsConf,
193+
config: config,
194+
version: config.Versions[0],
195+
handshakeChan: make(chan struct{}),
196+
logger: utils.DefaultLogger.WithPrefix("client"),
186197
}
187198
return c, nil
188199
}
@@ -197,6 +208,7 @@ func (c *client) dial(ctx context.Context) error {
197208
c.destConnID,
198209
c.srcConnID,
199210
c.connIDGenerator,
211+
c.statelessResetter,
200212
c.config,
201213
c.tlsConf,
202214
c.initialPacketNumber,

client_test.go

+31-3
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ var _ = Describe("Client", func() {
3333
destConnID protocol.ConnectionID,
3434
srcConnID protocol.ConnectionID,
3535
connIDGenerator ConnectionIDGenerator,
36+
statelessResetToken *statelessResetter,
3637
conf *Config,
3738
tlsConf *tls.Config,
3839
initialPacketNumber protocol.PacketNumber,
@@ -107,6 +108,7 @@ var _ = Describe("Client", func() {
107108
_ protocol.ConnectionID,
108109
_ protocol.ConnectionID,
109110
_ ConnectionIDGenerator,
111+
_ *statelessResetter,
110112
_ *Config,
111113
_ *tls.Config,
112114
_ protocol.PacketNumber,
@@ -124,7 +126,15 @@ var _ = Describe("Client", func() {
124126
conn.EXPECT().HandshakeComplete().Return(c)
125127
return conn
126128
}
127-
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, false)
129+
cl, err := newClient(
130+
packetConn,
131+
&protocol.DefaultConnectionIDGenerator{},
132+
newStatelessResetter(nil),
133+
populateConfig(config),
134+
tlsConf,
135+
nil,
136+
false,
137+
)
128138
Expect(err).ToNot(HaveOccurred())
129139
cl.packetHandlers = manager
130140
Expect(cl).ToNot(BeNil())
@@ -144,6 +154,7 @@ var _ = Describe("Client", func() {
144154
_ protocol.ConnectionID,
145155
_ protocol.ConnectionID,
146156
_ ConnectionIDGenerator,
157+
_ *statelessResetter,
147158
_ *Config,
148159
_ *tls.Config,
149160
_ protocol.PacketNumber,
@@ -161,7 +172,15 @@ var _ = Describe("Client", func() {
161172
return conn
162173
}
163174

164-
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, nil, true)
175+
cl, err := newClient(
176+
packetConn,
177+
&protocol.DefaultConnectionIDGenerator{},
178+
newStatelessResetter(nil),
179+
populateConfig(config),
180+
tlsConf,
181+
nil,
182+
true,
183+
)
165184
Expect(err).ToNot(HaveOccurred())
166185
cl.packetHandlers = manager
167186
Expect(cl).ToNot(BeNil())
@@ -181,6 +200,7 @@ var _ = Describe("Client", func() {
181200
_ protocol.ConnectionID,
182201
_ protocol.ConnectionID,
183202
_ ConnectionIDGenerator,
203+
_ *statelessResetter,
184204
_ *Config,
185205
_ *tls.Config,
186206
_ protocol.PacketNumber,
@@ -197,7 +217,13 @@ var _ = Describe("Client", func() {
197217
return conn
198218
}
199219
var closed bool
200-
cl, err := newClient(packetConn, &protocol.DefaultConnectionIDGenerator{}, populateConfig(config), tlsConf, func() { closed = true }, true)
220+
cl, err := newClient(
221+
packetConn,
222+
&protocol.DefaultConnectionIDGenerator{},
223+
newStatelessResetter(nil),
224+
populateConfig(config), tlsConf, func() { closed = true },
225+
true,
226+
)
201227
Expect(err).ToNot(HaveOccurred())
202228
cl.packetHandlers = manager
203229
Expect(cl).ToNot(BeNil())
@@ -266,6 +292,7 @@ var _ = Describe("Client", func() {
266292
_ protocol.ConnectionID,
267293
_ protocol.ConnectionID,
268294
_ ConnectionIDGenerator,
295+
_ *statelessResetter,
269296
configP *Config,
270297
_ *tls.Config,
271298
_ protocol.PacketNumber,
@@ -309,6 +336,7 @@ var _ = Describe("Client", func() {
309336
_ protocol.ConnectionID,
310337
connID protocol.ConnectionID,
311338
_ ConnectionIDGenerator,
339+
_ *statelessResetter,
312340
configP *Config,
313341
_ *tls.Config,
314342
pn protocol.PacketNumber,

conn_id_generator.go

+16-16
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,34 @@ type connIDGenerator struct {
1515
activeSrcConnIDs map[uint64]protocol.ConnectionID
1616
initialClientDestConnID *protocol.ConnectionID // nil for the client
1717

18-
addConnectionID func(protocol.ConnectionID)
19-
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
20-
removeConnectionID func(protocol.ConnectionID)
21-
retireConnectionID func(protocol.ConnectionID)
22-
replaceWithClosed func([]protocol.ConnectionID, []byte)
23-
queueControlFrame func(wire.Frame)
18+
addConnectionID func(protocol.ConnectionID)
19+
statelessResetter *statelessResetter
20+
removeConnectionID func(protocol.ConnectionID)
21+
retireConnectionID func(protocol.ConnectionID)
22+
replaceWithClosed func([]protocol.ConnectionID, []byte)
23+
queueControlFrame func(wire.Frame)
2424
}
2525

2626
func newConnIDGenerator(
2727
initialConnectionID protocol.ConnectionID,
2828
initialClientDestConnID *protocol.ConnectionID, // nil for the client
2929
addConnectionID func(protocol.ConnectionID),
30-
getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
30+
statelessResetter *statelessResetter,
3131
removeConnectionID func(protocol.ConnectionID),
3232
retireConnectionID func(protocol.ConnectionID),
3333
replaceWithClosed func([]protocol.ConnectionID, []byte),
3434
queueControlFrame func(wire.Frame),
3535
generator ConnectionIDGenerator,
3636
) *connIDGenerator {
3737
m := &connIDGenerator{
38-
generator: generator,
39-
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
40-
addConnectionID: addConnectionID,
41-
getStatelessResetToken: getStatelessResetToken,
42-
removeConnectionID: removeConnectionID,
43-
retireConnectionID: retireConnectionID,
44-
replaceWithClosed: replaceWithClosed,
45-
queueControlFrame: queueControlFrame,
38+
generator: generator,
39+
activeSrcConnIDs: make(map[uint64]protocol.ConnectionID),
40+
addConnectionID: addConnectionID,
41+
statelessResetter: statelessResetter,
42+
removeConnectionID: removeConnectionID,
43+
retireConnectionID: retireConnectionID,
44+
replaceWithClosed: replaceWithClosed,
45+
queueControlFrame: queueControlFrame,
4646
}
4747
m.activeSrcConnIDs[0] = initialConnectionID
4848
m.initialClientDestConnID = initialClientDestConnID
@@ -104,7 +104,7 @@ func (m *connIDGenerator) issueNewConnID() error {
104104
m.queueControlFrame(&wire.NewConnectionIDFrame{
105105
SequenceNumber: m.highestSeq + 1,
106106
ConnectionID: connID,
107-
StatelessResetToken: m.getStatelessResetToken(connID),
107+
StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
108108
})
109109
m.highestSeq++
110110
return nil

conn_id_generator_test.go

+4-7
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,11 @@ var _ = Describe("Connection ID Generator", func() {
1919
replacedWithClosed []protocol.ConnectionID
2020
queuedFrames []wire.Frame
2121
g *connIDGenerator
22+
statelessResetter *statelessResetter
2223
)
2324
initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
2425
initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe})
25-
26-
connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken {
27-
b := c.Bytes()[0]
28-
return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b}
29-
}
26+
statelessResetter = newStatelessResetter(nil)
3027

3128
BeforeEach(func() {
3229
addedConnIDs = nil
@@ -38,7 +35,7 @@ var _ = Describe("Connection ID Generator", func() {
3835
initialConnID,
3936
&initialClientDestConnID,
4037
func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
41-
connIDToToken,
38+
statelessResetter,
4239
func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
4340
func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
4441
func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) },
@@ -61,7 +58,7 @@ var _ = Describe("Connection ID Generator", func() {
6158
nf := f.(*wire.NewConnectionIDFrame)
6259
Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1))
6360
Expect(nf.ConnectionID.Len()).To(Equal(7))
64-
Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID)))
61+
Expect(nf.StatelessResetToken).To(Equal(statelessResetter.GetStatelessResetToken(nf.ConnectionID)))
6562
}
6663
})
6764

connection.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ func (p *receivedPacket) Clone() *receivedPacket {
8585

8686
type connRunner interface {
8787
Add(protocol.ConnectionID, packetHandler) bool
88-
GetStatelessResetToken(protocol.ConnectionID) protocol.StatelessResetToken
8988
Retire(protocol.ConnectionID)
9089
Remove(protocol.ConnectionID)
9190
ReplaceWithClosed([]protocol.ConnectionID, []byte)
@@ -225,7 +224,7 @@ var newConnection = func(
225224
destConnID protocol.ConnectionID,
226225
srcConnID protocol.ConnectionID,
227226
connIDGenerator ConnectionIDGenerator,
228-
statelessResetToken protocol.StatelessResetToken,
227+
statelessResetter *statelessResetter,
229228
conf *Config,
230229
tlsConf *tls.Config,
231230
tokenGenerator *handshake.TokenGenerator,
@@ -263,7 +262,7 @@ var newConnection = func(
263262
srcConnID,
264263
&clientDestConnID,
265264
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
266-
runner.GetStatelessResetToken,
265+
statelessResetter,
267266
runner.Remove,
268267
runner.Retire,
269268
runner.ReplaceWithClosed,
@@ -282,6 +281,7 @@ var newConnection = func(
282281
s.logger,
283282
)
284283
s.maxPayloadSizeEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
284+
statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID)
285285
params := &wire.TransportParameters{
286286
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
287287
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
@@ -340,6 +340,7 @@ var newClientConnection = func(
340340
destConnID protocol.ConnectionID,
341341
srcConnID protocol.ConnectionID,
342342
connIDGenerator ConnectionIDGenerator,
343+
statelessResetter *statelessResetter,
343344
conf *Config,
344345
tlsConf *tls.Config,
345346
initialPacketNumber protocol.PacketNumber,
@@ -372,7 +373,7 @@ var newClientConnection = func(
372373
srcConnID,
373374
nil,
374375
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
375-
runner.GetStatelessResetToken,
376+
statelessResetter,
376377
runner.Remove,
377378
runner.Retire,
378379
runner.ReplaceWithClosed,

connection_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func newServerTestConnection(
125125
protocol.ConnectionID{},
126126
srcConnID,
127127
&protocol.DefaultConnectionIDGenerator{},
128-
protocol.StatelessResetToken{},
128+
newStatelessResetter(nil),
129129
populateConfig(config),
130130
&tls.Config{},
131131
handshake.NewTokenGenerator(handshake.TokenProtectorKey{}),
@@ -180,6 +180,7 @@ func newClientTestConnection(
180180
destConnID,
181181
srcConnID,
182182
&protocol.DefaultConnectionIDGenerator{},
183+
newStatelessResetter(nil),
183184
populateConfig(config),
184185
&tls.Config{ServerName: "quic-go.net"},
185186
0,

mock_conn_runner_test.go

-38
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)