Skip to content

[communicator] add proxy_command support to connection block #36643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changes/v1.12/ENHANCEMENTS-20250301-165740.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
kind: ENHANCEMENTS
body: You can now utilize an SSH proxy command with the `proxy_command` argument in `connection` blocks of resources.
time: 2025-03-01T16:57:40.050317317Z
custom:
Issue: "394"
4 changes: 4 additions & 0 deletions internal/communicator/shared/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ var ConnectionBlockSupersetSchema = &configschema.Block{
Type: cty.String,
Optional: true,
},
"proxy_command": {
Type: cty.String,
Optional: true,
},
"bastion_host": {
Type: cty.String,
Optional: true,
Expand Down
160 changes: 157 additions & 3 deletions internal/communicator/ssh/communicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"math/rand"
"net"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -174,7 +175,12 @@ func (c *Communicator) Connect(o provisioners.UIOutput) (err error) {
))
}

if c.connInfo.ProxyHost != "" {
if c.connInfo.ProxyCommand != "" {
o.Output(fmt.Sprintf(
"Using configured proxy command: %s",
c.connInfo.ProxyCommand,
))
} else if c.connInfo.ProxyHost != "" {
o.Output(fmt.Sprintf(
"Using configured proxy host...\n"+
" ProxyHost: %s\n"+
Expand Down Expand Up @@ -329,7 +335,18 @@ func (c *Communicator) Disconnect() error {
if c.conn != nil {
conn := c.conn
c.conn = nil
return conn.Close()

// Close the connection
err := conn.Close()
if err != nil {
log.Printf("[ERROR] Error closing connection: %s", err)
return err
}

// Ensure the proxy command is terminated if this is a proxy command connection
if proxyConn, ok := conn.(*proxyCommandConn); ok {
proxyConn.Close()
}
}

return nil
Expand Down Expand Up @@ -603,7 +620,7 @@ func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Re

if err != nil {
if exitErr, ok := err.(*ssh.ExitError); ok {
// Otherwise, we have an ExitErorr, meaning we can just read
// Otherwise, we have an ExitError, meaning we can just read
// the exit status
log.Printf("[ERROR] %s", exitErr)

Expand Down Expand Up @@ -897,3 +914,140 @@ func quoteShell(args []string, targetPlatform string) (string, error) {
return "", fmt.Errorf("Cannot quote shell command, target platform unknown: %s", targetPlatform)

}

// ProxyCommandConnectFunc is a convenience method for returning a function
// that connects to a host using a proxy command.
func ProxyCommandConnectFunc(proxyCommand, addr string) func() (net.Conn, error) {
return func() (net.Conn, error) {
log.Printf("[DEBUG] Connecting to %s using proxy command: %s", addr, proxyCommand)

// Split the command into command and args
cmdParts := strings.Fields(proxyCommand)
if len(cmdParts) == 0 {
return nil, fmt.Errorf("Invalid proxy command: %s", proxyCommand)
}

// Create a buffer to capture stderr
stderrBuf := new(bytes.Buffer)

// Create context for command lifecycle management
ctx, cancel := context.WithCancel(context.Background())

// Create proxy command with context
cmd := exec.CommandContext(ctx, cmdParts[0], cmdParts[1:]...)
cmd.Stderr = stderrBuf

// Start the command with pipes
stdin, err := cmd.StdinPipe()
if err != nil {
cancel()
return nil, fmt.Errorf("Error creating stdin pipe for proxy command: %s", err)
}

stdout, err := cmd.StdoutPipe()
if err != nil {
stdin.Close()
cancel()
return nil, fmt.Errorf("Error creating stdout pipe for proxy command: %s", err)
}

if err := cmd.Start(); err != nil {
stdin.Close()
cancel()
return nil, fmt.Errorf("Error starting proxy command: %s", err)
}

// Create a wrapper that manages the command and pipes
conn := &proxyCommandConn{
cmd: cmd,
stderr: stderrBuf,
stdinPipe: stdin,
stdoutPipe: stdout,
ctx: ctx,
cancel: cancel,
}

return conn, nil
}
}

type proxyCommandConn struct {
cmd *exec.Cmd
stderr *bytes.Buffer
stdinPipe io.WriteCloser
stdoutPipe io.ReadCloser
closed bool
mutex sync.Mutex
ctx context.Context
cancel context.CancelFunc
}

// Read reads data from the connection (the command's stdout)
func (c *proxyCommandConn) Read(b []byte) (int, error) {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
return 0, io.EOF
}
c.mutex.Unlock()

return c.stdoutPipe.Read(b)
}

// Write writes data to the connection (the command's stdin)
func (c *proxyCommandConn) Write(b []byte) (int, error) {
c.mutex.Lock()
if c.closed {
c.mutex.Unlock()
return 0, io.ErrClosedPipe
}
c.mutex.Unlock()

return c.stdinPipe.Write(b)
}

func (c *proxyCommandConn) Close() error {
c.mutex.Lock()
defer c.mutex.Unlock()

if c.closed {
return nil
}
c.closed = true

log.Print("[DEBUG] Closing proxy command connection")
if err := c.stdinPipe.Close(); err != nil {
log.Printf("[ERROR] Error closing stdin pipe: %s", err)
}

// Cancel the context to signal the command to terminate
c.cancel()

// If the command failed, log the stderr output
if c.stderr.Len() > 0 {
log.Printf("[ERROR] Proxy command stderr: %s", c.stderr.String())
}

return nil
}

// Required methods to implement net.Conn interface
func (c *proxyCommandConn) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "local", Net: "unix"}
}

func (c *proxyCommandConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "remote", Net: "unix"}
}

func (c *proxyCommandConn) SetDeadline(t time.Time) error {
return nil // Not implemented
}

func (c *proxyCommandConn) SetReadDeadline(t time.Time) error {
return nil // Not implemented
}

func (c *proxyCommandConn) SetWriteDeadline(t time.Time) error {
return nil // Not implemented
}
74 changes: 50 additions & 24 deletions internal/communicator/ssh/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type connectionInfo struct {
ProxyPort uint16
ProxyUserName string
ProxyUserPassword string
ProxyCommand string

BastionUser string
BastionPassword string
Expand Down Expand Up @@ -151,6 +152,8 @@ func decodeConnInfo(v cty.Value) (*connectionInfo, error) {
}
case "agent_identity":
connInfo.AgentIdentity = v.AsString()
case "proxy_command":
connInfo.ProxyCommand = v.AsString()
}
}
return connInfo, nil
Expand Down Expand Up @@ -275,37 +278,60 @@ func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
return nil, err
}

var p *proxyInfo
var connectFunc func() (net.Conn, error)

// Check for conflicting connection methods
connectionMethodCount := 0
if connInfo.ProxyCommand != "" {
connectionMethodCount++
}
if connInfo.ProxyHost != "" {
p = newProxyInfo(
fmt.Sprintf("%s:%d", connInfo.ProxyHost, connInfo.ProxyPort),
connInfo.ProxyScheme,
connInfo.ProxyUserName,
connInfo.ProxyUserPassword,
)
connectionMethodCount++
}
if connInfo.BastionHost != "" {
connectionMethodCount++
}

connectFunc := ConnectFunc("tcp", host, p)
if connectionMethodCount > 1 {
log.Printf("[WARN] Multiple connection methods specified (proxy_command, proxy_host, bastion_host). Using proxy_command if specified, then proxy_host, then bastion_host.")
}

var bastionConf *ssh.ClientConfig
if connInfo.BastionHost != "" {
bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)

bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{
user: connInfo.BastionUser,
host: bastionHost,
privateKey: connInfo.BastionPrivateKey,
password: connInfo.BastionPassword,
hostKey: connInfo.HostKey,
certificate: connInfo.BastionCertificate,
sshAgent: sshAgent,
})
if err != nil {
return nil, err
// If a proxy command is specified, use it
if connInfo.ProxyCommand != "" {
connectFunc = ProxyCommandConnectFunc(connInfo.ProxyCommand, host)
} else {
// Otherwise, use the standard connection methods
var p *proxyInfo

if connInfo.ProxyHost != "" {
p = newProxyInfo(
fmt.Sprintf("%s:%d", connInfo.ProxyHost, connInfo.ProxyPort),
connInfo.ProxyScheme,
connInfo.ProxyUserName,
connInfo.ProxyUserPassword,
)
}

connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host, p)
connectFunc = ConnectFunc("tcp", host, p)

if connInfo.BastionHost != "" {
bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)

bastionConf, err := buildSSHClientConfig(sshClientConfigOpts{
user: connInfo.BastionUser,
host: bastionHost,
privateKey: connInfo.BastionPrivateKey,
password: connInfo.BastionPassword,
hostKey: connInfo.HostKey,
certificate: connInfo.BastionCertificate,
sshAgent: sshAgent,
})
if err != nil {
return nil, err
}

connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host, p)
}
}

config := &sshConfig{
Expand Down
57 changes: 57 additions & 0 deletions internal/communicator/ssh/provisioner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,32 @@ func TestProvisioner_connInfoProxy(t *testing.T) {
}
}

func TestProvisioner_connInfoProxyCommand(t *testing.T) {
v := cty.ObjectVal(map[string]cty.Value{
"type": cty.StringVal("ssh"),
"user": cty.StringVal("root"),
"password": cty.StringVal("supersecret"),
"private_key": cty.StringVal("someprivatekeycontents"),
"host": cty.StringVal("example.com"),
"port": cty.StringVal("22"),
"timeout": cty.StringVal("30s"),
"proxy_command": cty.StringVal("ssh -W example.com:2222 bastion-host.example.com"),
})

conf, err := parseConnectionInfo(v)
if err != nil {
t.Fatalf("err: %v", err)
}

if conf.Host != "example.com" {
t.Fatalf("bad: %v", conf)
}

if conf.ProxyCommand != "ssh -W example.com:2222 bastion-host.example.com" {
t.Fatalf("bad: %v", conf)
}
}

func TestProvisioner_stringBastionPort(t *testing.T) {
v := cty.ObjectVal(map[string]cty.Value{
"type": cty.StringVal("ssh"),
Expand Down Expand Up @@ -227,3 +253,34 @@ func TestProvisioner_invalidPortNumber(t *testing.T) {
t.Errorf("unexpected error\n got: %s\nwant: %s", got, want)
}
}

func TestProvisioner_multipleConnectionMethods(t *testing.T) {
// Create a connection info with all three connection methods specified
v := cty.ObjectVal(map[string]cty.Value{
"type": cty.StringVal("ssh"),
"user": cty.StringVal("root"),
"password": cty.StringVal("supersecret"),
"private_key": cty.StringVal("someprivatekeycontents"),
"host": cty.StringVal("example.com"),
"port": cty.StringVal("22"),
"proxy_command": cty.StringVal("ssh -W example.com:2222 bastion-host.example.com"),
"proxy_host": cty.StringVal("proxy.example.com"),
"bastion_host": cty.StringVal("bastion.example.com"),
})

conf, err := parseConnectionInfo(v)
if err != nil {
t.Fatalf("err: %v", err)
}

// Verify all three connection methods are parsed correctly
if conf.ProxyCommand != "ssh -W example.com:2222 bastion-host.example.com" {
t.Fatalf("bad proxy_command: %v", conf.ProxyCommand)
}
if conf.ProxyHost != "proxy.example.com" {
t.Fatalf("bad proxy_host: %v", conf.ProxyHost)
}
if conf.BastionHost != "bastion.example.com" {
t.Fatalf("bad bastion_host: %v", conf.BastionHost)
}
}
4 changes: 4 additions & 0 deletions internal/terraform/node_resource_validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,10 @@ var connectionBlockSupersetSchema = &configschema.Block{
Type: cty.String,
Optional: true,
},
"proxy_command": {
Type: cty.String,
Optional: true,
},
"bastion_host": {
Type: cty.String,
Optional: true,
Expand Down
Loading
Loading