Skip to content

Commit a085e56

Browse files
iangudgergvisor-bot
authored andcommitted
Add support for SO_REUSEADDR to UDP sockets/endpoints.
On UDP sockets, SO_REUSEADDR allows multiple sockets to bind to the same address, but only delivers packets to the most recently bound socket. This differs from the behavior of SO_REUSEADDR on TCP sockets. SO_REUSEADDR for TCP sockets will likely need an almost completely independent implementation. SO_REUSEADDR has some odd interactions with the similar SO_REUSEPORT. These interactions are tested fairly extensively and all but one particularly odd one (that honestly seems like a bug) behave the same on gVisor and Linux. PiperOrigin-RevId: 315844832
1 parent a87c74b commit a085e56

17 files changed

+255
-178
lines changed

pkg/tcpip/ports/ports.go

+79-37
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,27 @@ type Flags struct {
5454
LoadBalanced bool
5555
}
5656

57-
func (f Flags) bits() reuseFlag {
58-
var rf reuseFlag
57+
// Bits converts the Flags to their bitset form.
58+
func (f Flags) Bits() BitFlags {
59+
var rf BitFlags
5960
if f.MostRecent {
60-
rf |= mostRecentFlag
61+
rf |= MostRecentFlag
6162
}
6263
if f.LoadBalanced {
63-
rf |= loadBalancedFlag
64+
rf |= LoadBalancedFlag
6465
}
6566
return rf
6667
}
6768

69+
// Effective returns the effective behavior of a flag config.
70+
func (f Flags) Effective() Flags {
71+
e := f
72+
if e.LoadBalanced && e.MostRecent {
73+
e.MostRecent = false
74+
}
75+
return e
76+
}
77+
6878
// PortManager manages allocating, reserving and releasing ports.
6979
type PortManager struct {
7080
mu sync.RWMutex
@@ -78,83 +88,115 @@ type PortManager struct {
7888
hint uint32
7989
}
8090

81-
type reuseFlag int
91+
// BitFlags is a bitset representation of Flags.
92+
type BitFlags uint32
8293

8394
const (
84-
mostRecentFlag reuseFlag = 1 << iota
85-
loadBalancedFlag
95+
// MostRecentFlag represents Flags.MostRecent.
96+
MostRecentFlag BitFlags = 1 << iota
97+
98+
// LoadBalancedFlag represents Flags.LoadBalanced.
99+
LoadBalancedFlag
100+
101+
// nextFlag is the value that the next added flag will have.
102+
//
103+
// It is used to calculate FlagMask below. It is also the number of
104+
// valid flag states.
86105
nextFlag
87106

88-
flagMask = nextFlag - 1
107+
// FlagMask is a bit mask for BitFlags.
108+
FlagMask = nextFlag - 1
89109
)
90110

91-
type portNode struct {
92-
// refs stores the count for each possible flag combination.
111+
// ToFlags converts the bitset into a Flags struct.
112+
func (f BitFlags) ToFlags() Flags {
113+
return Flags{
114+
MostRecent: f&MostRecentFlag != 0,
115+
LoadBalanced: f&LoadBalancedFlag != 0,
116+
}
117+
}
118+
119+
// FlagCounter counts how many references each flag combination has.
120+
type FlagCounter struct {
121+
// refs stores the count for each possible flag combination, (0 though
122+
// FlagMask).
93123
refs [nextFlag]int
94124
}
95125

96-
func (p portNode) totalRefs() int {
126+
// AddRef increases the reference count for a specific flag combination.
127+
func (c *FlagCounter) AddRef(flags BitFlags) {
128+
c.refs[flags]++
129+
}
130+
131+
// DropRef decreases the reference count for a specific flag combination.
132+
func (c *FlagCounter) DropRef(flags BitFlags) {
133+
c.refs[flags]--
134+
}
135+
136+
// TotalRefs calculates the total number of references for all flag
137+
// combinations.
138+
func (c FlagCounter) TotalRefs() int {
97139
var total int
98-
for _, r := range p.refs {
140+
for _, r := range c.refs {
99141
total += r
100142
}
101143
return total
102144
}
103145

104-
// flagRefs returns the number of references with all specified flags.
105-
func (p portNode) flagRefs(flags reuseFlag) int {
146+
// FlagRefs returns the number of references with all specified flags.
147+
func (c FlagCounter) FlagRefs(flags BitFlags) int {
106148
var total int
107-
for i, r := range p.refs {
108-
if reuseFlag(i)&flags == flags {
149+
for i, r := range c.refs {
150+
if BitFlags(i)&flags == flags {
109151
total += r
110152
}
111153
}
112154
return total
113155
}
114156

115-
// allRefsHave returns if all references have all specified flags.
116-
func (p portNode) allRefsHave(flags reuseFlag) bool {
117-
for i, r := range p.refs {
118-
if reuseFlag(i)&flags == flags && r > 0 {
157+
// AllRefsHave returns if all references have all specified flags.
158+
func (c FlagCounter) AllRefsHave(flags BitFlags) bool {
159+
for i, r := range c.refs {
160+
if BitFlags(i)&flags != flags && r > 0 {
119161
return false
120162
}
121163
}
122164
return true
123165
}
124166

125-
// intersectionRefs returns the set of flags shared by all references.
126-
func (p portNode) intersectionRefs() reuseFlag {
127-
intersection := flagMask
128-
for i, r := range p.refs {
167+
// IntersectionRefs returns the set of flags shared by all references.
168+
func (c FlagCounter) IntersectionRefs() BitFlags {
169+
intersection := FlagMask
170+
for i, r := range c.refs {
129171
if r > 0 {
130-
intersection &= reuseFlag(i)
172+
intersection &= BitFlags(i)
131173
}
132174
}
133175
return intersection
134176
}
135177

136178
// deviceNode is never empty. When it has no elements, it is removed from the
137179
// map that references it.
138-
type deviceNode map[tcpip.NICID]portNode
180+
type deviceNode map[tcpip.NICID]FlagCounter
139181

140182
// isAvailable checks whether binding is possible by device. If not binding to a
141-
// device, check against all portNodes. If binding to a specific device, check
183+
// device, check against all FlagCounters. If binding to a specific device, check
142184
// against the unspecified device and the provided device.
143185
//
144186
// If either of the port reuse flags is enabled on any of the nodes, all nodes
145187
// sharing a port must share at least one reuse flag. This matches Linux's
146188
// behavior.
147189
func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
148-
flagBits := flags.bits()
190+
flagBits := flags.Bits()
149191
if bindToDevice == 0 {
150192
// Trying to binding all devices.
151193
if flagBits == 0 {
152194
// Can't bind because the (addr,port) is already bound.
153195
return false
154196
}
155-
intersection := flagMask
197+
intersection := FlagMask
156198
for _, p := range d {
157-
i := p.intersectionRefs()
199+
i := p.IntersectionRefs()
158200
intersection &= i
159201
if intersection&flagBits == 0 {
160202
// Can't bind because the (addr,port) was
@@ -165,17 +207,17 @@ func (d deviceNode) isAvailable(flags Flags, bindToDevice tcpip.NICID) bool {
165207
return true
166208
}
167209

168-
intersection := flagMask
210+
intersection := FlagMask
169211

170212
if p, ok := d[0]; ok {
171-
intersection = p.intersectionRefs()
213+
intersection = p.IntersectionRefs()
172214
if intersection&flagBits == 0 {
173215
return false
174216
}
175217
}
176218

177219
if p, ok := d[bindToDevice]; ok {
178-
i := p.intersectionRefs()
220+
i := p.IntersectionRefs()
179221
intersection &= i
180222
if intersection&flagBits == 0 {
181223
return false
@@ -324,7 +366,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
324366
if !s.isPortAvailableLocked(networks, transport, addr, port, flags, bindToDevice) {
325367
return false
326368
}
327-
flagBits := flags.bits()
369+
flagBits := flags.Bits()
328370

329371
// Reserve port on all network protocols.
330372
for _, network := range networks {
@@ -340,7 +382,7 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber
340382
m[addr] = d
341383
}
342384
n := d[bindToDevice]
343-
n.refs[flagBits]++
385+
n.AddRef(flagBits)
344386
d[bindToDevice] = n
345387
}
346388

@@ -353,7 +395,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
353395
s.mu.Lock()
354396
defer s.mu.Unlock()
355397

356-
flagBits := flags.bits()
398+
flagBits := flags.Bits()
357399

358400
for _, network := range networks {
359401
desc := portDescriptor{network, transport, port}
@@ -368,7 +410,7 @@ func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transp
368410
}
369411
n.refs[flagBits]--
370412
d[bindToDevice] = n
371-
if n.refs == [nextFlag]int{} {
413+
if n.TotalRefs() == 0 {
372414
delete(d, bindToDevice)
373415
}
374416
if len(d) == 0 {

pkg/tcpip/stack/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ go_test(
8989
"//pkg/tcpip/link/loopback",
9090
"//pkg/tcpip/network/ipv4",
9191
"//pkg/tcpip/network/ipv6",
92+
"//pkg/tcpip/ports",
9293
"//pkg/tcpip/transport/icmp",
9394
"//pkg/tcpip/transport/udp",
9495
"//pkg/waiter",

pkg/tcpip/stack/stack.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -1404,25 +1404,25 @@ func (s *Stack) RemoveWaker(nicID tcpip.NICID, addr tcpip.Address, waker *sleep.
14041404
// transport dispatcher. Received packets that match the provided id will be
14051405
// delivered to the given endpoint; specifying a nic is optional, but
14061406
// nic-specific IDs have precedence over global ones.
1407-
func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error {
1408-
return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice)
1407+
func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
1408+
return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
14091409
}
14101410

14111411
// UnregisterTransportEndpoint removes the endpoint with the given id from the
14121412
// stack transport dispatcher.
1413-
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
1414-
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
1413+
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
1414+
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
14151415
}
14161416

14171417
// StartTransportEndpointCleanup removes the endpoint with the given id from
14181418
// the stack transport dispatcher. It also transitions it to the cleanup stage.
1419-
func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) {
1419+
func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
14201420
s.mu.Lock()
14211421
defer s.mu.Unlock()
14221422

14231423
s.cleanupEndpoints[ep] = struct{}{}
14241424

1425-
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice)
1425+
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
14261426
}
14271427

14281428
// CompleteTransportEndpointCleanup removes the endpoint from the cleanup

0 commit comments

Comments
 (0)