5
5
package ssh
6
6
7
7
import (
8
+ "context"
8
9
"errors"
9
10
"fmt"
10
11
"io"
@@ -332,6 +333,37 @@ func (l *tcpListener) Addr() net.Addr {
332
333
return l .laddr
333
334
}
334
335
336
+ // DialContext initiates a connection to the addr from the remote host.
337
+ // If the supplied context is cancelled before the connection can be opened,
338
+ // ctx.Err() will be returned.
339
+ // The resulting connection has a zero LocalAddr() and RemoteAddr().
340
+ func (c * Client ) DialContext (ctx context.Context , n , addr string ) (net.Conn , error ) {
341
+ if err := ctx .Err (); err != nil {
342
+ return nil , err
343
+ }
344
+ type connErr struct {
345
+ conn net.Conn
346
+ err error
347
+ }
348
+ ch := make (chan connErr )
349
+ go func () {
350
+ conn , err := c .Dial (n , addr )
351
+ select {
352
+ case ch <- connErr {conn , err }:
353
+ case <- ctx .Done ():
354
+ if conn != nil {
355
+ conn .Close ()
356
+ }
357
+ }
358
+ }()
359
+ select {
360
+ case res := <- ch :
361
+ return res .conn , res .err
362
+ case <- ctx .Done ():
363
+ return nil , ctx .Err ()
364
+ }
365
+ }
366
+
335
367
// Dial initiates a connection to the addr from the remote host.
336
368
// The resulting connection has a zero LocalAddr() and RemoteAddr().
337
369
func (c * Client ) Dial (n , addr string ) (net.Conn , error ) {
@@ -347,7 +379,7 @@ func (c *Client) Dial(n, addr string) (net.Conn, error) {
347
379
if err != nil {
348
380
return nil , err
349
381
}
350
- ch , err = c .dial (net .IPv4zero .String (), 0 , host , int (port ))
382
+ ch , err = c .dialTCP (net .IPv4zero .String (), 0 , host , int (port ))
351
383
if err != nil {
352
384
return nil , err
353
385
}
@@ -393,7 +425,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
393
425
Port : 0 ,
394
426
}
395
427
}
396
- ch , err := c .dial (laddr .IP .String (), laddr .Port , raddr .IP .String (), raddr .Port )
428
+ ch , err := c .dialTCP (laddr .IP .String (), laddr .Port , raddr .IP .String (), raddr .Port )
397
429
if err != nil {
398
430
return nil , err
399
431
}
@@ -412,7 +444,7 @@ type channelOpenDirectMsg struct {
412
444
lport uint32
413
445
}
414
446
415
- func (c * Client ) dial (laddr string , lport int , raddr string , rport int ) (Channel , error ) {
447
+ func (c * Client ) dialTCP (laddr string , lport int , raddr string , rport int ) (Channel , error ) {
416
448
msg := channelOpenDirectMsg {
417
449
raddr : raddr ,
418
450
rport : uint32 (rport ),
0 commit comments