Skip to content
This repository was archived by the owner on May 20, 2020. It is now read-only.

Commit f33542f

Browse files
sepehrtheraissshentubot
authored andcommitted
Block for link address resolution
Previously, if address resolution for UDP or Ping sockets required sending packets using Write in Transport layer, Resolve would return ErrWouldBlock and Write would return ErrNoLinkAddress. Meanwhile startAddressResolution would run in background. Further calls to Write using same address would also return ErrNoLinkAddress until resolution has been completed successfully. Since Write is not allowed to block and System Calls need to be interruptible in System Call layer, the caller to Write is responsible for blocking upon return of ErrWouldBlock. Now, when startAddressResolution is called a notification channel for the completion of the address resolution is returned. The channel will traverse up to the calling function of Write as well as ErrNoLinkAddress. Once address resolution is complete (success or not) the channel is closed. The caller would call Write again to send packets and check if address resolution was compeleted successfully or not. Fixes google/gvisor#5 Change-Id: Idafaf31982bee1915ca084da39ae7bd468cebd93 PiperOrigin-RevId: 214962200
1 parent 9029f56 commit f33542f

19 files changed

+200
-126
lines changed

dhcp/client.go

+15-2
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,23 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg
195195
wopts := tcpip.WriteOptions{
196196
To: serverAddr,
197197
}
198-
if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
198+
var resCh <-chan struct{}
199+
if _, resCh, err = ep.Write(tcpip.SlicePayload(h), wopts); err != nil && resCh == nil {
199200
return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
200201
}
201202

203+
if resCh != nil {
204+
select {
205+
case <-resCh:
206+
case <-ctx.Done():
207+
return Config{}, fmt.Errorf("dhcp client address resolution: %v", tcpip.ErrAborted)
208+
}
209+
210+
if _, _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
211+
return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
212+
}
213+
}
214+
202215
we, ch := waiter.NewChannelEntry(nil)
203216
wq.EventRegister(&we, waiter.EventIn)
204217
defer wq.EventUnregister(&we)
@@ -289,7 +302,7 @@ func (c *Client) Request(ctx context.Context, requestedAddr tcpip.Address) (cfg
289302
reqOpts = append(reqOpts, option{optClientID, clientID})
290303
}
291304
h.setOptions(reqOpts)
292-
if _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
305+
if _, _, err := ep.Write(tcpip.SlicePayload(h), wopts); err != nil {
293306
return Config{}, fmt.Errorf("dhcp discovery write: %v", err)
294307
}
295308

dhcp/server.go

+14-1
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,22 @@ func (c *epConn) Read() (buffer.View, tcpip.FullAddress, error) {
9595
}
9696

9797
func (c *epConn) Write(b []byte, addr *tcpip.FullAddress) error {
98-
if _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil {
98+
_, resCh, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr})
99+
if err != nil && resCh == nil {
99100
return fmt.Errorf("write: %v", err)
100101
}
102+
103+
if resCh != nil {
104+
select {
105+
case <-resCh:
106+
case <-c.ctx.Done():
107+
return fmt.Errorf("dhcp server address resolution: %v", tcpip.ErrAborted)
108+
}
109+
110+
if _, _, err := c.ep.Write(tcpip.SlicePayload(b), tcpip.WriteOptions{To: addr}); err != nil {
111+
return fmt.Errorf("write: %v", err)
112+
}
113+
}
101114
return nil
102115
}
103116

tcpip/adapters/gonet/gonet.go

+29-6
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,22 @@ func (c *Conn) Write(b []byte) (int, error) {
393393
}
394394

395395
var n uintptr
396-
n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
396+
var resCh <-chan struct{}
397+
n, resCh, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
397398
nbytes += int(n)
398399
v.TrimFront(int(n))
400+
401+
if resCh != nil {
402+
select {
403+
case <-deadline:
404+
return nbytes, c.newOpError("write", &timeoutError{})
405+
case <-resCh:
406+
}
407+
408+
n, _, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
409+
nbytes += int(n)
410+
v.TrimFront(int(n))
411+
}
399412
}
400413

401414
if err == nil {
@@ -571,23 +584,33 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
571584
copy(v, b)
572585

573586
wopts := tcpip.WriteOptions{To: &fullAddr}
574-
n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
587+
n, resCh, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
588+
if resCh != nil {
589+
select {
590+
case <-deadline:
591+
return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
592+
case <-resCh:
593+
}
594+
595+
n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
596+
}
575597

576598
if err == tcpip.ErrWouldBlock {
577599
// Create wait queue entry that notifies a channel.
578600
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
579601
c.wq.EventRegister(&waitEntry, waiter.EventOut)
580602
defer c.wq.EventUnregister(&waitEntry)
581603
for {
582-
n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
583-
if err != tcpip.ErrWouldBlock {
584-
break
585-
}
586604
select {
587605
case <-deadline:
588606
return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
589607
case <-notifyCh:
590608
}
609+
610+
n, _, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
611+
if err != tcpip.ErrWouldBlock {
612+
break
613+
}
591614
}
592615
}
593616

tcpip/network/ipv6/icmp_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ func TestLinkResolution(t *testing.T) {
190190
if ctx.Err() != nil {
191191
break
192192
}
193-
if _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}); err == tcpip.ErrNoLinkAddress {
193+
if _, _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}); err == tcpip.ErrNoLinkAddress {
194194
// There's something asynchronous going on; yield to let it do its thing.
195195
runtime.Gosched()
196196
} else if err == nil {

tcpip/sample/tun_tcp_connect/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func writer(ch chan struct{}, ep tcpip.Endpoint) {
8080

8181
v.CapLength(n)
8282
for len(v) > 0 {
83-
n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
83+
n, _, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
8484
if err != nil {
8585
fmt.Println("Write failed:", err)
8686
return

tcpip/stack/linkaddrcache.go

+29-11
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,14 @@ type linkAddrEntry struct {
8888
linkAddr tcpip.LinkAddress
8989
expiration time.Time
9090
s entryState
91+
resDone bool
9192

9293
// wakers is a set of waiters for address resolution result. Anytime
9394
// state transitions out of 'incomplete' these waiters are notified.
9495
wakers map[*sleep.Waker]struct{}
9596

9697
cancel chan struct{}
98+
resCh chan struct{}
9799
}
98100

99101
func (e *linkAddrEntry) state() entryState {
@@ -182,15 +184,20 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress
182184
// someone waiting for address resolution on it.
183185
entry.changeState(expired)
184186
if entry.cancel != nil {
185-
entry.cancel <- struct{}{}
187+
if !entry.resDone {
188+
close(entry.resCh)
189+
}
190+
close(entry.cancel)
186191
}
187192

188193
*entry = linkAddrEntry{
189194
addr: k,
190195
linkAddr: v,
191196
expiration: time.Now().Add(c.ageLimit),
197+
resDone: false,
192198
wakers: make(map[*sleep.Waker]struct{}),
193199
cancel: make(chan struct{}, 1),
200+
resCh: make(chan struct{}, 1),
194201
}
195202

196203
c.cache[k] = entry
@@ -202,10 +209,10 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress
202209
}
203210

204211
// get reports any known link address for k.
205-
func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
212+
func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
206213
if linkRes != nil {
207214
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
208-
return addr, nil
215+
return addr, nil, nil
209216
}
210217
}
211218

@@ -214,10 +221,11 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
214221
if entry == nil || entry.state() == expired {
215222
c.mu.Unlock()
216223
if linkRes == nil {
217-
return "", tcpip.ErrNoLinkAddress
224+
return "", nil, tcpip.ErrNoLinkAddress
218225
}
219-
c.startAddressResolution(k, linkRes, localAddr, linkEP, waker)
220-
return "", tcpip.ErrWouldBlock
226+
227+
ch := c.startAddressResolution(k, linkRes, localAddr, linkEP, waker)
228+
return "", ch, tcpip.ErrWouldBlock
221229
}
222230
defer c.mu.Unlock()
223231

@@ -227,13 +235,13 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo
227235
// in that case it's safe to consider it ready.
228236
fallthrough
229237
case ready:
230-
return entry.linkAddr, nil
238+
return entry.linkAddr, nil, nil
231239
case failed:
232-
return "", tcpip.ErrNoLinkAddress
240+
return "", nil, tcpip.ErrNoLinkAddress
233241
case incomplete:
234242
// Address resolution is still in progress.
235243
entry.addWaker(waker)
236-
return "", tcpip.ErrWouldBlock
244+
return "", entry.resCh, tcpip.ErrWouldBlock
237245
default:
238246
panic(fmt.Sprintf("invalid cache entry state: %d", s))
239247
}
@@ -249,13 +257,13 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
249257
}
250258
}
251259

252-
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) {
260+
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) <-chan struct{} {
253261
c.mu.Lock()
254262
defer c.mu.Unlock()
255263

256264
// Look up again with lock held to ensure entry wasn't added by someone else.
257265
if e := c.cache[k]; e != nil && e.state() != expired {
258-
return
266+
return nil
259267
}
260268

261269
// Add 'incomplete' entry in the cache to mark that resolution is in progress.
@@ -274,13 +282,23 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link
274282
select {
275283
case <-time.After(c.resolutionTimeout):
276284
if stop := c.checkLinkRequest(k, i); stop {
285+
// If entry is evicted then resCh is already closed.
286+
c.mu.Lock()
287+
if e, ok := c.cache[k]; ok {
288+
if !e.resDone {
289+
e.resDone = true
290+
close(e.resCh)
291+
}
292+
}
293+
c.mu.Unlock()
277294
return
278295
}
279296
case <-cancel:
280297
return
281298
}
282299
}
283300
}()
301+
return e.resCh
284302
}
285303

286304
// checkLinkRequest checks whether previous attempt to resolve address has succeeded

tcpip/stack/linkaddrcache_test.go

+11-11
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe
7373
defer s.Done()
7474

7575
for {
76-
if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
76+
if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
7777
return got, err
7878
}
7979
s.Fetch(true)
@@ -95,7 +95,7 @@ func TestCacheOverflow(t *testing.T) {
9595
for i := len(testaddrs) - 1; i >= 0; i-- {
9696
e := testaddrs[i]
9797
c.add(e.addr, e.linkAddr)
98-
got, err := c.get(e.addr, nil, "", nil, nil)
98+
got, _, err := c.get(e.addr, nil, "", nil, nil)
9999
if err != nil {
100100
t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
101101
}
@@ -106,7 +106,7 @@ func TestCacheOverflow(t *testing.T) {
106106
// Expect to find at least half of the most recent entries.
107107
for i := 0; i < linkAddrCacheSize/2; i++ {
108108
e := testaddrs[i]
109-
got, err := c.get(e.addr, nil, "", nil, nil)
109+
got, _, err := c.get(e.addr, nil, "", nil, nil)
110110
if err != nil {
111111
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
112112
}
@@ -117,7 +117,7 @@ func TestCacheOverflow(t *testing.T) {
117117
// The earliest entries should no longer be in the cache.
118118
for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
119119
e := testaddrs[i]
120-
if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
120+
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
121121
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
122122
}
123123
}
@@ -143,7 +143,7 @@ func TestCacheConcurrent(t *testing.T) {
143143
// can fit in the cache, so our eviction strategy requires that
144144
// the last entry be present and the first be missing.
145145
e := testaddrs[len(testaddrs)-1]
146-
got, err := c.get(e.addr, nil, "", nil, nil)
146+
got, _, err := c.get(e.addr, nil, "", nil, nil)
147147
if err != nil {
148148
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
149149
}
@@ -152,7 +152,7 @@ func TestCacheConcurrent(t *testing.T) {
152152
}
153153

154154
e = testaddrs[0]
155-
if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
155+
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
156156
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
157157
}
158158
}
@@ -162,7 +162,7 @@ func TestCacheAgeLimit(t *testing.T) {
162162
e := testaddrs[0]
163163
c.add(e.addr, e.linkAddr)
164164
time.Sleep(50 * time.Millisecond)
165-
if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
165+
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
166166
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
167167
}
168168
}
@@ -172,7 +172,7 @@ func TestCacheReplace(t *testing.T) {
172172
e := testaddrs[0]
173173
l2 := e.linkAddr + "2"
174174
c.add(e.addr, e.linkAddr)
175-
got, err := c.get(e.addr, nil, "", nil, nil)
175+
got, _, err := c.get(e.addr, nil, "", nil, nil)
176176
if err != nil {
177177
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
178178
}
@@ -181,7 +181,7 @@ func TestCacheReplace(t *testing.T) {
181181
}
182182

183183
c.add(e.addr, l2)
184-
got, err = c.get(e.addr, nil, "", nil, nil)
184+
got, _, err = c.get(e.addr, nil, "", nil, nil)
185185
if err != nil {
186186
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
187187
}
@@ -206,7 +206,7 @@ func TestCacheResolution(t *testing.T) {
206206
// Check that after resolved, address stays in the cache and never returns WouldBlock.
207207
for i := 0; i < 10; i++ {
208208
e := testaddrs[len(testaddrs)-1]
209-
got, err := c.get(e.addr, linkRes, "", nil, nil)
209+
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
210210
if err != nil {
211211
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
212212
}
@@ -256,7 +256,7 @@ func TestStaticResolution(t *testing.T) {
256256

257257
addr := tcpip.Address("broadcast")
258258
want := tcpip.LinkAddress("mac_broadcast")
259-
got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
259+
got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
260260
if err != nil {
261261
t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
262262
}

tcpip/stack/registration.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ type LinkAddressCache interface {
289289
// registered with the network protocol, the cache attempts to resolve the address
290290
// and returns ErrWouldBlock. Waker is notified when address resolution is
291291
// complete (success or not).
292-
GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error)
292+
//
293+
// If address resolution is required, ErrNoLinkAddress and a notification channel is
294+
// returned for the top level caller to block. Channel is closed once address resolution
295+
// is complete (success or not).
296+
GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
293297

294298
// RemoveWaker removes a waker that has been added in GetLinkAddress().
295299
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)

0 commit comments

Comments
 (0)