Skip to content

Commit 94b6652

Browse files
committed
ssh: add (*Client).DialContext method
Fixes golang/go#20288.
1 parent 5fe8145 commit 94b6652

File tree

2 files changed

+75
-3
lines changed

2 files changed

+75
-3
lines changed

ssh/tcpip.go

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package ssh
66

77
import (
8+
"context"
89
"errors"
910
"fmt"
1011
"io"
@@ -332,6 +333,37 @@ func (l *tcpListener) Addr() net.Addr {
332333
return l.laddr
333334
}
334335

336+
// DialContext initiates a connection to the addr from the remote host.
337+
// If the supplied context is cancelled before the connection can be opened,
338+
// ctx.Err() will be returned.
339+
// The resulting connection has a zero LocalAddr() and RemoteAddr().
340+
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
341+
if err := ctx.Err(); err != nil {
342+
return nil, err
343+
}
344+
type connErr struct {
345+
conn net.Conn
346+
err error
347+
}
348+
ch := make(chan connErr)
349+
go func() {
350+
conn, err := c.Dial(n, addr)
351+
select {
352+
case ch <- connErr{conn, err}:
353+
case <-ctx.Done():
354+
if conn != nil {
355+
conn.Close()
356+
}
357+
}
358+
}()
359+
select {
360+
case res := <-ch:
361+
return res.conn, res.err
362+
case <-ctx.Done():
363+
return nil, ctx.Err()
364+
}
365+
}
366+
335367
// Dial initiates a connection to the addr from the remote host.
336368
// The resulting connection has a zero LocalAddr() and RemoteAddr().
337369
func (c *Client) Dial(n, addr string) (net.Conn, error) {
@@ -347,7 +379,7 @@ func (c *Client) Dial(n, addr string) (net.Conn, error) {
347379
if err != nil {
348380
return nil, err
349381
}
350-
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
382+
ch, err = c.dialTCP(net.IPv4zero.String(), 0, host, int(port))
351383
if err != nil {
352384
return nil, err
353385
}
@@ -393,7 +425,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
393425
Port: 0,
394426
}
395427
}
396-
ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
428+
ch, err := c.dialTCP(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port)
397429
if err != nil {
398430
return nil, err
399431
}
@@ -412,7 +444,7 @@ type channelOpenDirectMsg struct {
412444
lport uint32
413445
}
414446

415-
func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) {
447+
func (c *Client) dialTCP(laddr string, lport int, raddr string, rport int) (Channel, error) {
416448
msg := channelOpenDirectMsg{
417449
raddr: raddr,
418450
rport: uint32(rport),

ssh/tcpip_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
package ssh
66

77
import (
8+
"context"
9+
"net"
810
"testing"
11+
"time"
912
)
1013

1114
func TestAutoPortListenBroken(t *testing.T) {
@@ -18,3 +21,40 @@ func TestAutoPortListenBroken(t *testing.T) {
1821
t.Errorf("version %q marked as broken", works)
1922
}
2023
}
24+
25+
func TestClientImplementsDialContext(t *testing.T) {
26+
type ContextDialer interface {
27+
DialContext(context.Context, string, string) (net.Conn, error)
28+
}
29+
var _ ContextDialer = &Client{}
30+
}
31+
32+
func TestClientDialContextWithCancel(t *testing.T) {
33+
c := &Client{}
34+
ctx, cancel := context.WithCancel(context.Background())
35+
cancel()
36+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
37+
if err != context.Canceled {
38+
t.Errorf("DialContext: got nil error, expected %v", context.Canceled)
39+
}
40+
}
41+
42+
func TestClientDialContextWithDeadline(t *testing.T) {
43+
c := &Client{}
44+
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
45+
defer cancel()
46+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
47+
if err != context.DeadlineExceeded {
48+
t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
49+
}
50+
}
51+
52+
func TestClientDialContextWithTimeout(t *testing.T) {
53+
c := &Client{}
54+
ctx, cancel := context.WithTimeout(context.Background(), 0)
55+
defer cancel()
56+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
57+
if err != context.DeadlineExceeded {
58+
t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
59+
}
60+
}

0 commit comments

Comments
 (0)