Skip to content

Commit 72945cd

Browse files
committed
proxy: add Dial (with context)
The existing API does not allow client code to take advantage of Dialer implementations that implement DialContext receivers. This a familiar API, see net.Dialer. Fixes golang/go#27874 Fixes golang/go#19354 Fixes golang/go#17759 Fixes golang/go#13455
1 parent 9ce7a69 commit 72945cd

File tree

7 files changed

+254
-18
lines changed

7 files changed

+254
-18
lines changed

proxy/dial.go

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright 2019 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package proxy
6+
7+
import (
8+
"context"
9+
"net"
10+
)
11+
12+
// A ContextDialer dials using a context.
13+
type ContextDialer interface {
14+
DialContext(ctx context.Context, network, address string) (net.Conn, error)
15+
}
16+
17+
// Dial works like DialContext on net.Dialer but using a dialer returned by FromEnvironment.
18+
//
19+
// The passed ctx is only used for returning the Conn, not the lifetime of the Conn.
20+
//
21+
// Custom dialers (registered via RegisterDialerType) that do not implement ContextDialer
22+
// can leak a goroutine for as long as it takes the underlying Dialer implementation to timeout.
23+
//
24+
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
25+
func Dial(ctx context.Context, network, address string) (net.Conn, error) {
26+
d := FromEnvironment()
27+
if xd, ok := d.(ContextDialer); ok {
28+
return xd.DialContext(ctx, network, address)
29+
}
30+
return dialContext(ctx, d, network, address)
31+
}
32+
33+
// WARNING: this can leak a goroutine for as long as the underlying Dialer implementation takes to timeout
34+
// A Conn returned from a successful Dial after the context has been cancelled will be immediately closed.
35+
func dialContext(ctx context.Context, d Dialer, network, address string) (net.Conn, error) {
36+
var (
37+
conn net.Conn
38+
done = make(chan struct{}, 1)
39+
err error
40+
)
41+
go func() {
42+
conn, err = d.Dial(network, address)
43+
close(done)
44+
if conn != nil && ctx.Err() != nil {
45+
conn.Close()
46+
}
47+
}()
48+
select {
49+
case <-ctx.Done():
50+
err = ctx.Err()
51+
case <-done:
52+
}
53+
return conn, err
54+
}

proxy/dial_test.go

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
// Copyright 2019 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package proxy
6+
7+
import (
8+
"context"
9+
"fmt"
10+
"net"
11+
"os"
12+
"testing"
13+
"time"
14+
15+
"golang.org/x/net/internal/sockstest"
16+
)
17+
18+
func TestDial(t *testing.T) {
19+
ResetProxyEnv()
20+
t.Run("DirectWithCancel", func(t *testing.T) {
21+
defer ResetProxyEnv()
22+
l, err := net.Listen("tcp", "127.0.0.1:0")
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
defer l.Close()
27+
_, port, err := net.SplitHostPort(l.Addr().String())
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
ctx, cancel := context.WithCancel(context.Background())
32+
defer cancel()
33+
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
34+
if err != nil {
35+
t.Fatal(err)
36+
}
37+
c.Close()
38+
})
39+
t.Run("DirectWithTimeout", func(t *testing.T) {
40+
defer ResetProxyEnv()
41+
l, err := net.Listen("tcp", "127.0.0.1:0")
42+
if err != nil {
43+
t.Fatal(err)
44+
}
45+
defer l.Close()
46+
_, port, err := net.SplitHostPort(l.Addr().String())
47+
if err != nil {
48+
t.Fatal(err)
49+
}
50+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
51+
defer cancel()
52+
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
53+
if err != nil {
54+
t.Fatal(err)
55+
}
56+
c.Close()
57+
})
58+
t.Run("DirectWithTimeoutExceeded", func(t *testing.T) {
59+
defer ResetProxyEnv()
60+
l, err := net.Listen("tcp", "127.0.0.1:0")
61+
if err != nil {
62+
t.Fatal(err)
63+
}
64+
defer l.Close()
65+
_, port, err := net.SplitHostPort(l.Addr().String())
66+
if err != nil {
67+
t.Fatal(err)
68+
}
69+
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
70+
time.Sleep(time.Millisecond)
71+
defer cancel()
72+
c, err := Dial(ctx, l.Addr().Network(), net.JoinHostPort("", port))
73+
if err == nil {
74+
defer c.Close()
75+
t.Fatal("failed to timeout")
76+
}
77+
})
78+
t.Run("SOCKS5", func(t *testing.T) {
79+
defer ResetProxyEnv()
80+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
81+
if err != nil {
82+
t.Fatal(err)
83+
}
84+
defer s.Close()
85+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
86+
t.Fatal(err)
87+
}
88+
c, err := Dial(context.Background(), s.TargetAddr().Network(), s.TargetAddr().String())
89+
if err != nil {
90+
t.Fatal(err)
91+
}
92+
c.Close()
93+
})
94+
t.Run("SOCKS5WithTimeout", func(t *testing.T) {
95+
defer ResetProxyEnv()
96+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
97+
if err != nil {
98+
t.Fatal(err)
99+
}
100+
defer s.Close()
101+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
102+
t.Fatal(err)
103+
}
104+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
105+
defer cancel()
106+
c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
107+
if err != nil {
108+
t.Fatal(err)
109+
}
110+
c.Close()
111+
})
112+
t.Run("SOCKS5WithTimeoutExceeded", func(t *testing.T) {
113+
defer ResetProxyEnv()
114+
s, err := sockstest.NewServer(sockstest.NoAuthRequired, sockstest.NoProxyRequired)
115+
if err != nil {
116+
t.Fatal(err)
117+
}
118+
defer s.Close()
119+
if err = os.Setenv("ALL_PROXY", fmt.Sprintf("socks5://%s", s.Addr().String())); err != nil {
120+
t.Fatal(err)
121+
}
122+
ctx, cancel := context.WithTimeout(context.Background(), time.Nanosecond)
123+
time.Sleep(time.Millisecond)
124+
defer cancel()
125+
c, err := Dial(ctx, s.TargetAddr().Network(), s.TargetAddr().String())
126+
if err == nil {
127+
defer c.Close()
128+
t.Fatal("failed to timeout")
129+
}
130+
})
131+
}

proxy/direct.go

+8
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package proxy
66

77
import (
8+
"context"
89
"net"
910
)
1011

@@ -13,6 +14,13 @@ type direct struct{}
1314
// Direct is a direct proxy: one that makes network connections directly.
1415
var Direct = direct{}
1516

17+
// Dial directly invokes net.Dial with the supplied parameters.
1618
func (direct) Dial(network, addr string) (net.Conn, error) {
1719
return net.Dial(network, addr)
1820
}
21+
22+
// DialContext instantiates a net.Dialer and invokes its DialContext receiver with the supplied parameters.
23+
func (direct) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
24+
var d net.Dialer
25+
return d.DialContext(ctx, network, addr)
26+
}

proxy/per_host.go

+15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package proxy
66

77
import (
8+
"context"
89
"net"
910
"strings"
1011
)
@@ -41,6 +42,20 @@ func (p *PerHost) Dial(network, addr string) (c net.Conn, err error) {
4142
return p.dialerForRequest(host).Dial(network, addr)
4243
}
4344

45+
// DialContext connects to the address addr on the given network through either
46+
// defaultDialer or bypass.
47+
func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) {
48+
host, _, err := net.SplitHostPort(addr)
49+
if err != nil {
50+
return nil, err
51+
}
52+
d := p.dialerForRequest(host)
53+
if x, ok := d.(ContextDialer); ok {
54+
return x.DialContext(ctx, network, addr)
55+
}
56+
return dialContext(ctx, d, network, addr)
57+
}
58+
4459
func (p *PerHost) dialerForRequest(host string) Dialer {
4560
if ip := net.ParseIP(host); ip != nil {
4661
for _, net := range p.bypassNetworks {

proxy/per_host_test.go

+37-16
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package proxy
66

77
import (
8+
"context"
89
"errors"
910
"net"
1011
"reflect"
@@ -21,10 +22,6 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) {
2122
}
2223

2324
func TestPerHost(t *testing.T) {
24-
var def, bypass recordingProxy
25-
perHost := NewPerHost(&def, &bypass)
26-
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
27-
2825
expectedDef := []string{
2926
"example.com:123",
3027
"1.2.3.4:123",
@@ -39,17 +36,41 @@ func TestPerHost(t *testing.T) {
3936
"[1000::]:123",
4037
}
4138

42-
for _, addr := range expectedDef {
43-
perHost.Dial("tcp", addr)
44-
}
45-
for _, addr := range expectedBypass {
46-
perHost.Dial("tcp", addr)
47-
}
39+
t.Run("Dial", func(t *testing.T) {
40+
var def, bypass recordingProxy
41+
perHost := NewPerHost(&def, &bypass)
42+
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
43+
for _, addr := range expectedDef {
44+
perHost.Dial("tcp", addr)
45+
}
46+
for _, addr := range expectedBypass {
47+
perHost.Dial("tcp", addr)
48+
}
4849

49-
if !reflect.DeepEqual(expectedDef, def.addrs) {
50-
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
51-
}
52-
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
53-
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
54-
}
50+
if !reflect.DeepEqual(expectedDef, def.addrs) {
51+
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
52+
}
53+
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
54+
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
55+
}
56+
})
57+
58+
t.Run("DialContext", func(t *testing.T) {
59+
var def, bypass recordingProxy
60+
perHost := NewPerHost(&def, &bypass)
61+
perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16")
62+
for _, addr := range expectedDef {
63+
perHost.DialContext(context.Background(), "tcp", addr)
64+
}
65+
for _, addr := range expectedBypass {
66+
perHost.DialContext(context.Background(), "tcp", addr)
67+
}
68+
69+
if !reflect.DeepEqual(expectedDef, def.addrs) {
70+
t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef)
71+
}
72+
if !reflect.DeepEqual(expectedBypass, bypass.addrs) {
73+
t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass)
74+
}
75+
})
5576
}

proxy/proxy.go

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
)
1616

1717
// A Dialer is a means to establish a connection.
18+
// Custom dialers should also implement ContextDialer.
1819
type Dialer interface {
1920
// Dial connects to the given address via the proxy.
2021
Dial(network, addr string) (c net.Conn, err error)

proxy/socks5.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,14 @@ import (
1717
func SOCKS5(network, address string, auth *Auth, forward Dialer) (Dialer, error) {
1818
d := socks.NewDialer(network, address)
1919
if forward != nil {
20-
d.ProxyDial = func(_ context.Context, network string, address string) (net.Conn, error) {
21-
return forward.Dial(network, address)
20+
if f, ok := forward.(ContextDialer); ok {
21+
d.ProxyDial = func(ctx context.Context, network string, address string) (net.Conn, error) {
22+
return f.DialContext(ctx, network, address)
23+
}
24+
} else {
25+
d.ProxyDial = func(_ context.Context, network string, address string) (net.Conn, error) {
26+
return forward.Dial(network, address)
27+
}
2228
}
2329
}
2430
if auth != nil {

0 commit comments

Comments
 (0)