Skip to content

Commit b2d7c26

Browse files
ydnargopherbot
authored andcommitted
ssh: add (*Client).DialContext method
This change adds DialContext to ssh.Client, which opens a TCP-IP connection tunneled over the SSH connection. This is useful for proxying network connections, e.g. setting (net/http.Transport).DialContext. Fixes golang/go#20288. Change-Id: I110494c00962424ea803065535ebe2209364ac27 GitHub-Last-Rev: 3176984 GitHub-Pull-Request: #260 Reviewed-on: https://go-review.googlesource.com/c/crypto/+/504735 Run-TryBot: Nicola Murino <[email protected]> Run-TryBot: Han-Wen Nienhuys <[email protected]> Auto-Submit: Nicola Murino <[email protected]> Reviewed-by: Han-Wen Nienhuys <[email protected]> Reviewed-by: Dmitri Shuralyov <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Nicola Murino <[email protected]> Commit-Queue: Nicola Murino <[email protected]>
1 parent 1c17e20 commit b2d7c26

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

ssh/tcpip.go

+35
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package ssh
66

77
import (
8+
"context"
89
"errors"
910
"fmt"
1011
"io"
@@ -332,6 +333,40 @@ func (l *tcpListener) Addr() net.Addr {
332333
return l.laddr
333334
}
334335

336+
// DialContext initiates a connection to the addr from the remote host.
337+
//
338+
// The provided Context must be non-nil. If the context expires before the
339+
// connection is complete, an error is returned. Once successfully connected,
340+
// any expiration of the context will not affect the connection.
341+
//
342+
// See func Dial for additional information.
343+
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
344+
if err := ctx.Err(); err != nil {
345+
return nil, err
346+
}
347+
type connErr struct {
348+
conn net.Conn
349+
err error
350+
}
351+
ch := make(chan connErr)
352+
go func() {
353+
conn, err := c.Dial(n, addr)
354+
select {
355+
case ch <- connErr{conn, err}:
356+
case <-ctx.Done():
357+
if conn != nil {
358+
conn.Close()
359+
}
360+
}
361+
}()
362+
select {
363+
case res := <-ch:
364+
return res.conn, res.err
365+
case <-ctx.Done():
366+
return nil, ctx.Err()
367+
}
368+
}
369+
335370
// Dial initiates a connection to the addr from the remote host.
336371
// The resulting connection has a zero LocalAddr() and RemoteAddr().
337372
func (c *Client) Dial(n, addr string) (net.Conn, error) {

ssh/tcpip_test.go

+33
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
package ssh
66

77
import (
8+
"context"
9+
"net"
810
"testing"
11+
"time"
912
)
1013

1114
func TestAutoPortListenBroken(t *testing.T) {
@@ -18,3 +21,33 @@ func TestAutoPortListenBroken(t *testing.T) {
1821
t.Errorf("version %q marked as broken", works)
1922
}
2023
}
24+
25+
func TestClientImplementsDialContext(t *testing.T) {
26+
type ContextDialer interface {
27+
DialContext(context.Context, string, string) (net.Conn, error)
28+
}
29+
// Belt and suspenders assertion, since package net does not
30+
// declare a ContextDialer type.
31+
var _ ContextDialer = &net.Dialer{}
32+
var _ ContextDialer = &Client{}
33+
}
34+
35+
func TestClientDialContextWithCancel(t *testing.T) {
36+
c := &Client{}
37+
ctx, cancel := context.WithCancel(context.Background())
38+
cancel()
39+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
40+
if err != context.Canceled {
41+
t.Errorf("DialContext: got nil error, expected %v", context.Canceled)
42+
}
43+
}
44+
45+
func TestClientDialContextWithDeadline(t *testing.T) {
46+
c := &Client{}
47+
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
48+
defer cancel()
49+
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
50+
if err != context.DeadlineExceeded {
51+
t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
52+
}
53+
}

ssh/test/dial_unix_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package test
99
// direct-tcpip and direct-streamlocal functional tests
1010

1111
import (
12+
"context"
1213
"fmt"
1314
"io"
1415
"net"
@@ -46,7 +47,11 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) {
4647
}
4748
}()
4849

49-
conn, err := sshConn.Dial(n, l.Addr().String())
50+
ctx, cancel := context.WithCancel(context.Background())
51+
conn, err := sshConn.DialContext(ctx, n, l.Addr().String())
52+
// Canceling the context after dial should have no effect
53+
// on the opened connection.
54+
cancel()
5055
if err != nil {
5156
t.Fatalf("Dial: %v", err)
5257
}

0 commit comments

Comments
 (0)