Skip to content

Commit 24a796c

Browse files
committed
ssh: add (*Client).DialContext method
Fixes golang/go#20288.
1 parent 5fe8145 commit 24a796c

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

ssh/tcpip.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package ssh
66

77
import (
8+
"context"
89
"errors"
910
"fmt"
1011
"io"
@@ -332,6 +333,37 @@ func (l *tcpListener) Addr() net.Addr {
332333
return l.laddr
333334
}
334335

336+
// DialContext initiates a connection to the addr from the remote host.
337+
// If the supplied context is cancelled before the connection can be opened,
338+
// ctx.Err() will be returned.
339+
// The resulting connection has a zero LocalAddr() and RemoteAddr().
340+
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
341+
if err := ctx.Err(); err != nil {
342+
return nil, err
343+
}
344+
type connErr struct {
345+
conn net.Conn
346+
err error
347+
}
348+
ch := make(chan connErr)
349+
go func() {
350+
conn, err := c.Dial(n, addr)
351+
select {
352+
case ch <- connErr{conn, err}:
353+
case <-ctx.Done():
354+
if conn != nil {
355+
conn.Close()
356+
}
357+
}
358+
}()
359+
select {
360+
case res := <-ch:
361+
return res.conn, res.err
362+
case <-ctx.Done():
363+
return nil, ctx.Err()
364+
}
365+
}
366+
335367
// Dial initiates a connection to the addr from the remote host.
336368
// The resulting connection has a zero LocalAddr() and RemoteAddr().
337369
func (c *Client) Dial(n, addr string) (net.Conn, error) {

ssh/tcpip_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
package ssh
66

77
import (
8+
"context"
9+
"net"
810
"testing"
11+
"time"
912
)
1013

1114
func TestAutoPortListenBroken(t *testing.T) {
@@ -18,3 +21,40 @@ func TestAutoPortListenBroken(t *testing.T) {
1821
t.Errorf("version %q marked as broken", works)
1922
}
2023
}
24+
25+
func TestClientImplementsDialContext(t *testing.T) {
26+
type ContextDialer interface {
27+
DialContext(context.Context, string, string) (net.Conn, error)
28+
}
29+
var _ ContextDialer = &Client{}
30+
}
31+
32+
func TestClientDialContextWithCancel(t *testing.T) {
33+
c := &Client{}
34+
ctx, cancel := context.WithCancel(context.Background())
35+
cancel()
36+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
37+
if err != context.Canceled {
38+
t.Errorf("DialContext: got nil error, expected %v", context.Canceled)
39+
}
40+
}
41+
42+
func TestClientDialContextWithDeadline(t *testing.T) {
43+
c := &Client{}
44+
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
45+
defer cancel()
46+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
47+
if err != context.DeadlineExceeded {
48+
t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
49+
}
50+
}
51+
52+
func TestClientDialContextWithTimeout(t *testing.T) {
53+
c := &Client{}
54+
ctx, cancel := context.WithTimeout(context.Background(), 0)
55+
defer cancel()
56+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
57+
if err != context.DeadlineExceeded {
58+
t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
59+
}
60+
}

0 commit comments

Comments
 (0)