diff --git a/README.md b/README.md index f11dcfc..106fdd4 100644 --- a/README.md +++ b/README.md @@ -79,37 +79,210 @@ Sample configuration that exposes: looks like this ```yaml - server_addr: SERVER_IP:5223 - tunnels: - webui: - proto: http - addr: localhost:8080 - auth: user:password - host: webui.my-tunnel-host.com - ssh: - proto: tcp - addr: 192.168.0.5:22 - remote_addr: 0.0.0.0:22 +server_addr: SERVER_IP:5223 +tunnels: + webui: + proto: http + local_addr: localhost:8080 + auth: user:password + host: webui.my-tunnel-host.com + ssh: + proto: tcp + local_addr: 192.168.0.5:22 + remote_addr: 0.0.0.0:22 ``` Configuration options: +* `name`: set the client name (not required). +* `description`: set the client name (not required). * `server_addr`: server TCP address, i.e. `54.12.12.45:5223` * `tls_crt`: path to client TLS certificate, *default:* `client.crt` *in the config file directory* * `tls_key`: path to client TLS certificate key, *default:* `client.key` *in the config file directory* * `root_ca`: path to trusted root certificate authority pool file, if empty any server certificate is accepted * `tunnels / [name]` * `proto`: tunnel protocol, `http` or `tcp` - * `addr`: forward traffic to this local port number or network address, for `proto=http` this can be full URL i.e. `https://machine/sub/path/?plus=params`, supports URL schemes `http` and `https` + * `local_addr`: forward traffic to this local port number or network address, for `proto=http` this can be full URL i.e. `https://machine/sub/path/?plus=params`, supports URL schemes `http` and `https` * `auth`: (`proto=http`) (optional) basic authentication credentials to enforce on tunneled requests, format `user:password` * `host`: (`proto=http`) hostname to request (requires reserved name and DNS CNAME) * `remote_addr`: (`proto=tcp`) bind the remote TCP address + * `disabled`: if set to `true` this tunnel has be ignored. * `backoff` * `interval`: how long client would wait before redialing the server if connection was lost, exponential backoff initial interval, *default:* `500ms` * `multiplier`: interval multiplier if reconnect failed, *default:* `1.5` * `max_interval`: maximal time client would wait before redialing the server, *default:* `1m` * `max_time`: maximal time client would try to reconnect to the server if connection was lost, set `0` to never stop trying, *default:* `15m` +## Client Registration + +### Client + +Configuration file `tunnel.yaml`: + +```yaml +registered: true +server_addr: SERVER_IP:5223 +tunnels: + webui: + proto: http + local_addr: localhost:8080 + auth: user:password + ssh: + proto: tcp + local_addr: 192.168.0.5:22 +``` + +Configuration options: + +* `registered`: registered client. Set as `true`. +* `server_addr`: server TCP address, i.e. `54.12.12.45:5223` +* `tls_crt`: path to client TLS certificate, *default:* `client.crt` *in the config file directory* +* `tls_key`: path to client TLS certificate key, *default:* `client.key` *in the config file directory* +* `root_ca`: path to trusted root certificate authority pool file, if empty any server certificate is accepted +* `tunnels / [name]` + * `proto`: tunnel protocol, `http` or `tcp` + * `local_addr`: forward traffic to this local port number or network address, for `proto=http` this can be full URL i.e. `https://machine/sub/path/?plus=params`, supports URL schemes `http` and `https` + * `auth`: (`proto=http`) (optional) basic authentication credentials to enforce on tunneled requests, format `user:password` + * `disabled`: if set to `true` this tunnel has be ignored. +* `backoff` + * `interval`: how long client would wait before redialing the server if connection was lost, exponential backoff initial interval, *default:* `500ms` + * `multiplier`: interval multiplier if reconnect failed, *default:* `1.5` + * `max_interval`: maximal time client would wait before redialing the server, *default:* `1m` + * `max_time`: maximal time client would try to reconnect to the server if connection was lost, set `0` to never stop trying, *default:* `15m` + +### Server + +Create registered clients database structure if not exists: + +```bash +$ mkdir clients +``` + +On client, get ID: `$ tunnel id` + +Registered **Client DIR** `clients/CLIENT_ID`. + +Configure client (file `config.yaml`). +Example `clients/SS2KPSV-5KG2URM-WYUEDLY-FBDAD7A-MUUVTDX-TL7KL45-2PQEQAD-IN4LVAH/config.yaml`: + +```yaml +connections: 4 +tunnels: + webui: + proto: http + host: mps.dea.ufv.br + ssh: + proto: tcp + remote_addr: 127.0.0.1:2222 +``` + +Configuration options: +* `disabled`: if set to `true` block this client. Or create empty file `disabled` on client dir. +Example: `$ touch clients/CLIENT_ID/disabled`. +* `connections`: number of connections in addition to the main connection. Default is `0`. +* `tunnels / [name]` + * `proto`: tunnel protocol, `http` or `tcp` + * `host`: (`proto=http`) hostname to request (requires reserved name and DNS CNAME) + * `remote_addr`: (`proto=tcp`) bind the remote TCP address + * `disabled`: if set to `true` this tunnel has be ignored. + +### Run + +Server: +```bash +$ tunneld -clientsDir ./clients + +# or (using default clients dir, only if dir exists) + +$ tunneld +``` + +Client: +```bash +$ tunnel start-all +``` + +#### With Supervisor + +Server example: + + [program:tunneld] + directory=/etc/tunneld + command=/usr/bin/tunneld + autostart=true + autorestart=true + stdout_logfile=/var/log/tunneld.log + stdout_logfile_maxbytes=5MB + stdout_logfile_backups=2 + redirect_stderr=true + user = root + environment = HOME="/root", USER="root" + +Client example: + + [program:tunnel] + directory=/etc/tunnel + command=/usr/bin/tunnel + autostart=true + autorestart=true + startsecs=10 + stdout_logfile=/var/log/tunnel.log + stdout_logfile_maxbytes=5MB + stdout_logfile_backups=2 + redirect_stderr=true + user = root + environment = HOME="/root", USER="root" + +### Embeded + +#### tunnel +Example for embed `tunnel` in your code. + +```go +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunnel" + +func tunnelClient() { + options, err := tunnel.ParseArgs(false, "mycmd", "start-all") + if err != nil { + panic(err) + } + tunnel.MainConfigOptions(&tunnel.ClientConfig{ + ServerAddr: "domain.com:2000", + Tunnels: map[string]*tunnel.Tunnel{ + "main": { + Protocol: "tcp", + LocalAddr: "localhost:5000", + RemoteAddr: "domain.com:5000", + }, + }, + }, options) +} + +func main() { + go tunnelClient() + + // ... my system code +} +``` + +#### tunneld + +Example for embed `tunneld` in your code. + +```go +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunneld" + +func main() { + go tunneld.MainArgs("mycmd", "-log-level", "3", "-httpAddr", ":19000", "-httpsAddr", ":19001") + + // ... my system code +} +``` + ## How it works A client opens TLS connection to a server. The server accepts connections from known clients only. The client is recognized by its TLS certificate ID. The server is publicly available and proxies incoming connections to the client. Then the connection is further proxied in the client's network. diff --git a/cmd/tunnel/config.go b/cli/tunnel/config.go similarity index 65% rename from cmd/tunnel/config.go rename to cli/tunnel/config.go index da20ffd..e710674 100644 --- a/cmd/tunnel/config.go +++ b/cli/tunnel/config.go @@ -2,11 +2,12 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "fmt" "io/ioutil" + "os" "path/filepath" "time" @@ -21,6 +22,7 @@ const ( DefaultBackoffMultiplier = 1.5 DefaultBackoffMaxInterval = 60 * time.Second DefaultBackoffMaxTime = 15 * time.Minute + ConfigFileSTDIN = "-" ) // BackoffConfig defines behavior of staggering reconnection retries. @@ -34,7 +36,7 @@ type BackoffConfig struct { // Tunnel defines a tunnel. type Tunnel struct { Protocol string `yaml:"proto,omitempty"` - Addr string `yaml:"addr,omitempty"` + LocalAddr string `yaml:"local_addr,omitempty"` Auth string `yaml:"auth,omitempty"` Host string `yaml:"host,omitempty"` RemoteAddr string `yaml:"remote_addr,omitempty"` @@ -42,6 +44,7 @@ type Tunnel struct { // ClientConfig is a tunnel client configuration. type ClientConfig struct { + Registered bool `yaml:"registered"` ServerAddr string `yaml:"server_addr"` TLSCrt string `yaml:"tls_crt"` TLSKey string `yaml:"tls_key"` @@ -50,15 +53,30 @@ type ClientConfig struct { Tunnels map[string]*Tunnel `yaml:"tunnels"` } -func loadClientConfigFromFile(file string) (*ClientConfig, error) { - buf, err := ioutil.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("failed to read file %q: %s", file, err) +func LoadClientConfigFromFile(file string) (*ClientConfig, error) { + var ( + buf []byte + err error + ) + + var configDir string + + if file == ConfigFileSTDIN { + if buf, err = ioutil.ReadAll(os.Stdin); err != nil { + return nil, fmt.Errorf("failed to read config from STDIN: ", err) + } + file = "STDIN" + configDir = "." + } else { + configDir = filepath.Dir(file) + if buf, err = ioutil.ReadFile(file); err != nil { + return nil, fmt.Errorf("failed to read file %q: %s", file, err) + } } c := ClientConfig{ - TLSCrt: filepath.Join(filepath.Dir(file), "client.crt"), - TLSKey: filepath.Join(filepath.Dir(file), "client.key"), + TLSCrt: "client.crt", + TLSKey: "client.key", Backoff: BackoffConfig{ Interval: DefaultBackoffInterval, Multiplier: DefaultBackoffMultiplier, @@ -71,6 +89,14 @@ func loadClientConfigFromFile(file string) (*ClientConfig, error) { return nil, fmt.Errorf("failed to parse file %q: %s", file, err) } + if filepath.Dir(c.TLSCrt) == "." { + c.TLSCrt = filepath.Join(configDir, c.TLSCrt) + } + + if filepath.Dir(c.TLSKey) == "." { + c.TLSKey = filepath.Join(configDir, c.TLSKey) + } + if c.ServerAddr == "" { return nil, fmt.Errorf("server_addr: missing") } @@ -81,11 +107,11 @@ func loadClientConfigFromFile(file string) (*ClientConfig, error) { for name, t := range c.Tunnels { switch t.Protocol { case proto.HTTP: - if err := validateHTTP(t); err != nil { + if err := validateHTTP(c.Registered, t); err != nil { return nil, fmt.Errorf("%s %s", name, err) } case proto.TCP, proto.TCP4, proto.TCP6: - if err := validateTCP(t); err != nil { + if err := validateTCP(c.Registered, t); err != nil { return nil, fmt.Errorf("%s %s", name, err) } default: @@ -96,15 +122,15 @@ func loadClientConfigFromFile(file string) (*ClientConfig, error) { return &c, nil } -func validateHTTP(t *Tunnel) error { +func validateHTTP(registered bool, t *Tunnel) error { var err error - if t.Host == "" { + if !registered && t.Host == "" { return fmt.Errorf("host: missing") } - if t.Addr == "" { + if t.LocalAddr == "" { return fmt.Errorf("addr: missing") } - if t.Addr, err = normalizeURL(t.Addr); err != nil { + if t.LocalAddr, err = normalizeURL(t.LocalAddr); err != nil { return fmt.Errorf("addr: %s", err) } @@ -117,15 +143,17 @@ func validateHTTP(t *Tunnel) error { return nil } -func validateTCP(t *Tunnel) error { +func validateTCP(registered bool, t *Tunnel) error { var err error - if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil { - return fmt.Errorf("remote_addr: %s", err) + if !registered { + if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil { + return fmt.Errorf("remote_addr: %s", err) + } } - if t.Addr == "" { + if t.LocalAddr == "" { return fmt.Errorf("addr: missing") } - if t.Addr, err = normalizeAddress(t.Addr); err != nil { + if t.LocalAddr, err = normalizeAddress(t.LocalAddr); err != nil { return fmt.Errorf("addr: %s", err) } diff --git a/cmd/tunnel/normalize.go b/cli/tunnel/normalize.go similarity index 98% rename from cmd/tunnel/normalize.go rename to cli/tunnel/normalize.go index 2f9f84d..d205b6d 100644 --- a/cmd/tunnel/normalize.go +++ b/cli/tunnel/normalize.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "fmt" diff --git a/cmd/tunnel/normalize_test.go b/cli/tunnel/normalize_test.go similarity index 99% rename from cmd/tunnel/normalize_test.go rename to cli/tunnel/normalize_test.go index 5db7eb1..ef65c2d 100644 --- a/cmd/tunnel/normalize_test.go +++ b/cli/tunnel/normalize_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "strings" diff --git a/cmd/tunnel/options.go b/cli/tunnel/options.go similarity index 60% rename from cmd/tunnel/options.go rename to cli/tunnel/options.go index ef3176f..296d75d 100644 --- a/cmd/tunnel/options.go +++ b/cli/tunnel/options.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "flag" @@ -26,19 +26,31 @@ Examples: tunnel -config config.yaml -log-level 2 start ssh tunnel start-all -config.yaml: +Normal Client config: +client.yaml: server_addr: SERVER_IP:5223 tunnels: webui: proto: http - addr: localhost:8080 + local_addr: localhost:8080 auth: user:password host: webui.my-tunnel-host.com ssh: proto: tcp - addr: 192.168.0.5:22 + local_addr: 192.168.0.5:22 remote_addr: 0.0.0.0:22 +Registered Client config: +registered_client.yaml: + registered: true + server_addr: SERVER_IP:5223 + tunnels: + webui: + local_addr: localhost:8080 + auth: user:password + ssh: + local_addr: 192.168.0.5:22 + Author: Written by M. Matczuk (mmatczuk@gmail.com) @@ -62,17 +74,40 @@ type options struct { args []string } -func parseArgs() (*options, error) { - config := flag.String("config", "tunnel.yml", "Path to tunnel configuration file") - logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") - version := flag.Bool("version", false, "Prints tunnel version") - flag.Parse() +func (opt options) Command() string { + return opt.command +} + +func (opt options) LogLevel() int { + return opt.logLevel +} + +func (opt options) Args() []string { + return opt.args +} + +func ParseArgs(hasConfig bool, args ...string) (opts *options, err error) { + var config *string + cli := flag.NewFlagSet(args[0], flag.ExitOnError) + args = args[1:] + if hasConfig { + config = cli.String("config", "tunnel.yaml", + "Path to tunnel configuration file. Use HYPHEN (-) for read STDIN.") + } else { + var s = "" + config = &s + } + logLevel := cli.Int("log-level", 1, "Level of messages to log, 0-3") + version := cli.Bool("version", false, "Prints tunnel version") + if err = cli.Parse(args); err != nil { + return nil, err + } - opts := &options{ + opts = &options{ config: *config, logLevel: *logLevel, version: *version, - command: flag.Arg(0), + command: cli.Arg(0), } if opts.version { @@ -81,20 +116,18 @@ func parseArgs() (*options, error) { switch opts.command { case "": - flag.Usage() - os.Exit(2) case "id", "list": - opts.args = flag.Args()[1:] + opts.args = cli.Args()[1:] if len(opts.args) > 0 { return nil, fmt.Errorf("list takes no arguments") } case "start": - opts.args = flag.Args()[1:] + opts.args = cli.Args()[1:] if len(opts.args) == 0 { return nil, fmt.Errorf("you must specify at least one tunnel to start") } case "start-all": - opts.args = flag.Args()[1:] + opts.args = cli.Args()[1:] if len(opts.args) > 0 { return nil, fmt.Errorf("start-all takes no arguments") } diff --git a/cmd/tunnel/tunnel.go b/cli/tunnel/tunnel.go similarity index 82% rename from cmd/tunnel/tunnel.go rename to cli/tunnel/tunnel.go index 9bce9ea..99f4b97 100644 --- a/cmd/tunnel/tunnel.go +++ b/cli/tunnel/tunnel.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "crypto/tls" @@ -23,26 +23,41 @@ import ( "github.com/mmatczuk/go-http-tunnel/proto" ) -func main() { - opts, err := parseArgs() +func Main() { + MainArgs(os.Args...) +} + +func MainArgs(args ...string) { + opts, err := ParseArgs(true, args...) if err != nil { fatal(err.Error()) } + MainOpts(opts) +} +func MainOpts(opts *options) { if opts.version { fmt.Println(version) return } - logger := log.NewFilterLogger(log.NewStdLogger(), opts.logLevel) - // read configuration file - config, err := loadClientConfigFromFile(opts.config) + config, err := LoadClientConfigFromFile(opts.config) if err != nil { fatal("configuration error: %s", err) } - switch opts.command { + MainConfigOptions(config, opts) +} + +func MainConfigOptions(config *ClientConfig, options *options) { + MainConfig(config, options.logLevel, options.command, options.args...) +} + +func MainConfig(config *ClientConfig, logLevel int, command string, args ...string) { + logger := log.NewFilterLogger(log.NewStdLogger(), logLevel) + + switch command { case "id": cert, err := tls.LoadX509KeyPair(config.TLSCrt, config.TLSKey) if err != nil { @@ -70,7 +85,7 @@ func main() { return case "start": tunnels := make(map[string]*Tunnel) - for _, arg := range opts.args { + for _, arg := range args { t, ok := config.Tunnels[arg] if !ok { fatal("no such tunnel %q", arg) @@ -102,6 +117,7 @@ func main() { Tunnels: tunnels(config.Tunnels), Proxy: proxy(config.Tunnels, logger), Logger: logger, + Registered: config.Registered, }) if err != nil { fatal("failed to create client: %s", err) @@ -158,10 +174,11 @@ func tunnels(m map[string]*Tunnel) map[string]*proto.Tunnel { for name, t := range m { p[name] = &proto.Tunnel{ - Protocol: t.Protocol, - Host: t.Host, - Auth: t.Auth, - Addr: t.RemoteAddr, + Protocol: t.Protocol, + Host: t.Host, + Auth: t.Auth, + LocalAddr: t.LocalAddr, + RemoteAddr: t.RemoteAddr, } } @@ -175,13 +192,13 @@ func proxy(m map[string]*Tunnel, logger log.Logger) tunnel.ProxyFunc { for _, t := range m { switch t.Protocol { case proto.HTTP: - u, err := url.Parse(t.Addr) + u, err := url.Parse(t.LocalAddr) if err != nil { fatal("invalid tunnel address: %s", err) } httpURL[t.Host] = u case proto.TCP, proto.TCP4, proto.TCP6: - tcpAddr[t.RemoteAddr] = t.Addr + tcpAddr[t.LocalAddr] = t.LocalAddr } } diff --git a/cmd/tunnel/version.go b/cli/tunnel/version.go similarity index 65% rename from cmd/tunnel/version.go rename to cli/tunnel/version.go index 189c430..28756d1 100644 --- a/cmd/tunnel/version.go +++ b/cli/tunnel/version.go @@ -2,6 +2,11 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel var version = "snapshot" + +// Return the tunnel version +func GetVersion() string { + return version +} diff --git a/cmd/tunneld/banner.go b/cli/tunneld/banner.go similarity index 97% rename from cmd/tunneld/banner.go rename to cli/tunneld/banner.go index c5cf848..d4f88c1 100644 --- a/cmd/tunneld/banner.go +++ b/cli/tunneld/banner.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld const banner = ` ______ __ __________________ __ __ diff --git a/cmd/tunneld/options.go b/cli/tunneld/options.go similarity index 87% rename from cmd/tunneld/options.go rename to cli/tunneld/options.go index 06994dc..353f7ec 100644 --- a/cmd/tunneld/options.go +++ b/cli/tunneld/options.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld import ( "flag" @@ -44,11 +44,14 @@ type options struct { tlsKey string rootCA string clients string + clientsDir string logLevel int version bool } -func parseArgs() *options { +func ParseArgs(args ...string) *options { + flag := flag.NewFlagSet(args[0], flag.ExitOnError) + args = args[1:] httpAddr := flag.String("httpAddr", ":80", "Public address for HTTP connections, empty string to disable") httpsAddr := flag.String("httpsAddr", ":443", "Public address listening for HTTPS connections, emptry string to disable") tunnelAddr := flag.String("tunnelAddr", ":5223", "Public address listening for tunnel client") @@ -56,9 +59,10 @@ func parseArgs() *options { tlsKey := flag.String("tlsKey", "server.key", "Path to a TLS key file") rootCA := flag.String("rootCA", "", "Path to the trusted certificate chian used for client certificate authentication, if empty any client certificate is accepted") clients := flag.String("clients", "", "Comma-separated list of tunnel client ids, if empty accept all clients") + clientsDir := flag.String("clientsDir", "clients", "Database directory of registered clients") logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") version := flag.Bool("version", false, "Prints tunneld version") - flag.Parse() + flag.Parse(args) return &options{ httpAddr: *httpAddr, @@ -68,6 +72,7 @@ func parseArgs() *options { tlsKey: *tlsKey, rootCA: *rootCA, clients: *clients, + clientsDir: *clientsDir, logLevel: *logLevel, version: *version, } diff --git a/cmd/tunneld/tunneld.go b/cli/tunneld/tunneld.go similarity index 84% rename from cmd/tunneld/tunneld.go rename to cli/tunneld/tunneld.go index d7c1eba..fc1daeb 100644 --- a/cmd/tunneld/tunneld.go +++ b/cli/tunneld/tunneld.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld import ( "crypto/tls" @@ -20,8 +20,12 @@ import ( "github.com/mmatczuk/go-http-tunnel/log" ) -func main() { - opts := parseArgs() +func Main() { + MainArgs(os.Args...) +} + +func MainArgs(args ...string) { + opts := ParseArgs(args...) if opts.version { fmt.Println(version) @@ -37,7 +41,16 @@ func main() { fatal("failed to configure tls: %s", err) } - autoSubscribe := opts.clients == "" + var clientsProvider tunnel.RegisteredClientsProvider + if opts.clientsDir != "" { + if _, err := os.Stat(opts.clientsDir); err == nil { + clientsProvider = &tunnel.RegisteredClientsFileSystemProvider{opts.clientsDir} + } else { + opts.clientsDir = "" + } + } + + autoSubscribe := opts.clients == "" && opts.clientsDir == "" // setup server server, err := tunnel.NewServer(&tunnel.ServerConfig{ @@ -45,12 +58,13 @@ func main() { AutoSubscribe: autoSubscribe, TLSConfig: tlsconf, Logger: logger, + RegisteredClientsProvider: clientsProvider, }) if err != nil { fatal("failed to create server: %s", err) } - if !autoSubscribe { + if !autoSubscribe && clientsProvider == nil { for _, c := range strings.Split(opts.clients, ",") { if c == "" { fatal("empty client id") diff --git a/cmd/tunneld/version.go b/cli/tunneld/version.go similarity index 91% rename from cmd/tunneld/version.go rename to cli/tunneld/version.go index 189c430..696adfe 100644 --- a/cmd/tunneld/version.go +++ b/cli/tunneld/version.go @@ -2,6 +2,6 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld var version = "snapshot" diff --git a/client.go b/client.go index 3647ad7..d2e206f 100644 --- a/client.go +++ b/client.go @@ -11,11 +11,14 @@ import ( "fmt" "net" "net/http" + "os" "sync" "time" "golang.org/x/net/http2" + "strconv" + "github.com/mmatczuk/go-http-tunnel/log" "github.com/mmatczuk/go-http-tunnel/proto" ) @@ -33,27 +36,41 @@ type ClientConfig struct { // Backoff specifies backoff policy on server connection retry. If nil // when dial fails it will not be retried. Backoff Backoff - // Tunnels specifies the tunnels client requests to be opened on server. + // Tunnels specifies the tunnels controller requests to be opened on server. Tunnels map[string]*proto.Tunnel // Proxy is ProxyFunc responsible for transferring data between server // and local services. Proxy ProxyFunc // Logger is optional logger. If nil logging is disabled. Logger log.Logger + // Name the controller name. + Name string + // Description Client Description. + Description string + // Registered if client is registered on server. + Registered bool } -// Client is responsible for creating connection to the server, handling control -// messages. It uses ProxyFunc for transferring data between server and local -// services. -type Client struct { - config *ClientConfig - +type ClientToServerConn struct { + c *Client conn net.Conn connMu sync.Mutex httpServer *http2.Server serverErr error lastDisconnect time.Time logger log.Logger + handler http.HandlerFunc +} + +// Client is responsible for creating connection to the server, handling control +// messages. It uses ProxyFunc for transferring data between server and local +// services. +type Client struct { + *ClientToServerConn + config *ClientConfig + connectionCount int + connections []*ClientToServerConn + hostname string } // NewClient creates a new unconnected Client based on configuration. Caller @@ -78,19 +95,40 @@ func NewClient(config *ClientConfig) (*Client, error) { } c := &Client{ - config: config, + config: config, + } + + c.hostname, _ = os.Hostname() + + c.ClientToServerConn = &ClientToServerConn{ + c: c, httpServer: &http2.Server{}, logger: logger, - } + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + if r.Header.Get(proto.HeaderError) != "" { + c.handleHandshakeError(w, r) + } else { + c.handleHandshake(w, r) + } + return + } + c.ClientToServerConn.serveHTTP(w, r) + }} return c, nil } -// Start connects client to the server, it returns error if there is a +// Hostname returns the client hostname +func (c *Client) Hostname() string { + return c.hostname +} + +// Start connects controller to the server, it returns error if there is a // connection error, or server cannot open requested tunnels. On connection // error a backoff policy is used to reestablish the connection. When connected // HTTP/2 server is started to handle ControlMessages. -func (c *Client) Start() error { +func (c *ClientToServerConn) Start() error { c.logger.Log( "level", 1, "action", "start", @@ -102,8 +140,14 @@ func (c *Client) Start() error { return err } + handler := c.handler + + if handler == nil { + handler = http.HandlerFunc(c.serveHTTP) + } + c.httpServer.ServeConn(conn, &http2.ServeConnOpts{ - Handler: http.HandlerFunc(c.serveHTTP), + Handler: handler, }) c.logger.Log( @@ -131,7 +175,7 @@ func (c *Client) Start() error { } } -func (c *Client) connect() (net.Conn, error) { +func (c *ClientToServerConn) connect() (net.Conn, error) { c.connMu.Lock() defer c.connMu.Unlock() @@ -148,11 +192,11 @@ func (c *Client) connect() (net.Conn, error) { return conn, nil } -func (c *Client) dial() (net.Conn, error) { +func (c *ClientToServerConn) dial() (net.Conn, error) { var ( network = "tcp" - addr = c.config.ServerAddr - tlsConfig = c.config.TLSClientConfig + addr = c.c.config.ServerAddr + tlsConfig = c.c.config.TLSClientConfig ) doDial := func() (conn net.Conn, err error) { @@ -163,8 +207,8 @@ func (c *Client) dial() (net.Conn, error) { "addr", addr, ) - if c.config.DialTLS != nil { - conn, err = c.config.DialTLS(network, addr, tlsConfig) + if c.c.config.DialTLS != nil { + conn, err = c.c.config.DialTLS(network, addr, tlsConfig) } else { d := &net.Dialer{ Timeout: DefaultTimeout, @@ -200,7 +244,7 @@ func (c *Client) dial() (net.Conn, error) { return } - b := c.config.Backoff + b := c.c.config.Backoff if b == nil { return doDial() } @@ -230,16 +274,7 @@ func (c *Client) dial() (net.Conn, error) { } } -func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodConnect { - if r.Header.Get(proto.HeaderError) != "" { - c.handleHandshakeError(w, r) - } else { - c.handleHandshake(w, r) - } - return - } - +func (c *ClientToServerConn) serveHTTP(w http.ResponseWriter, r *http.Request) { msg, err := proto.ReadControlMessage(r) if err != nil { c.logger.Log( @@ -257,7 +292,7 @@ func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) { ) switch msg.Action { case proto.ActionProxy: - c.config.Proxy(w, r.Body, msg) + c.c.config.Proxy(w, r.Body, msg) default: c.logger.Log( "level", 0, @@ -297,7 +332,37 @@ func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - b, err := json.Marshal(c.config.Tunnels) + var info interface{} + + if c.config.Registered { + if cc := r.Header.Get(HeaderConnectionsCount); cc != "" { + c.connectionCount, _ = strconv.Atoi(cc) + } + + tunnels := map[string]*proto.RegisteredTunnel{} + for name, t := range c.config.Tunnels { + tunnels[name] = &proto.RegisteredTunnel{ + Disabled: t.Disabled, + LocalAddr: t.LocalAddr, + Auth: t.Auth, + } + } + + info = &RegisteredClientInfo{ + Hostname: c.hostname, + Tunnels: tunnels, + } + } else { + info = &RegisteredClientConfig{ + Name: c.config.Name, + Description: c.config.Description, + Tunnels: c.config.Tunnels, + Hostname: c.hostname, + } + } + + b, err := json.Marshal(info) + if err != nil { c.logger.Log( "level", 0, @@ -306,11 +371,26 @@ func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) { ) return } - w.Write(b) + + if _, err := w.Write(b); err == nil && c.config.Registered && c.connectionCount > 0 { + if len(c.connections) == 0 { + for i := 0; i < c.connectionCount; i++ { + c.connections = append(c.connections, &ClientToServerConn{ + c: c, + httpServer: c.httpServer, + logger: log.NewContext(c.logger).With(fmt.Sprint("[", i, "]")), + }) + } + + for _, cc := range c.connections { + go cc.Start() + } + } + } } -// Stop disconnects client from server. -func (c *Client) Stop() { +// Stop disconnects controller from server. +func (c *ClientToServerConn) Stop() { c.connMu.Lock() defer c.connMu.Unlock() diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go new file mode 100644 index 0000000..1f50381 --- /dev/null +++ b/cmd/tunnel/main.go @@ -0,0 +1,11 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunnel" + +func main() { + tunnel.Main() +} diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go new file mode 100644 index 0000000..0e12e3c --- /dev/null +++ b/cmd/tunneld/main.go @@ -0,0 +1,11 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunneld" + +func main() { + tunneld.Main() +} diff --git a/common.go b/common.go new file mode 100644 index 0000000..900263c --- /dev/null +++ b/common.go @@ -0,0 +1,17 @@ +package tunnel + +import ( + "net" + + "github.com/mmatczuk/go-http-tunnel/id" + "github.com/mmatczuk/go-http-tunnel/log" +) + +const HeaderConnectionsCount = "X-Connections-Count" + +type clientConnectionController struct { + cfg *RegisteredClientConfig + conn net.Conn + ID id.ID + logger log.Logger +} diff --git a/doc.go b/doc.go index 091aae9..279adcb 100644 --- a/doc.go +++ b/doc.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -// Package tunnel is fast and secure client/server package that enables proxying +// Package tunnel is fast and secure controller/server package that enables proxying // public connections to your local machine over a tunnel connection from the // local machine to the public server. package tunnel diff --git a/errors.go b/errors.go index 1a6b53b..ea36386 100644 --- a/errors.go +++ b/errors.go @@ -7,9 +7,10 @@ package tunnel import "errors" var ( - errClientNotSubscribed = errors.New("client not subscribed") - errClientNotConnected = errors.New("client not connected") - errClientAlreadyConnected = errors.New("client already connected") + errClientManyConnections = errors.New("controller has many connections") + errClientNotSubscribed = errors.New("controller not subscribed") + errClientNotConnected = errors.New("controller not connected") + errClientAlreadyConnected = errors.New("controller already connected") errUnauthorised = errors.New("unauthorised") ) diff --git a/integration_test.go b/integration_test.go index 78fc039..38e8ba0 100644 --- a/integration_test.go +++ b/integration_test.go @@ -89,12 +89,13 @@ func makeEcho(t testing.TB) (http net.Listener, tcp net.Listener) { return } -func makeTunnelServer(t testing.TB) *tunnel.Server { +func makeTunnelServer(t testing.TB, clientsProvider tunnel.RegisteredClientsProvider) *tunnel.Server { s, err := tunnel.NewServer(&tunnel.ServerConfig{ Addr: ":0", - AutoSubscribe: true, + AutoSubscribe: clientsProvider == nil, TLSConfig: tlsConfig(), Logger: log.NewStdLogger(), + RegisteredClientsProvider: clientsProvider, }) if err != nil { t.Fatal(err) @@ -118,13 +119,14 @@ func makeTunnelClient(t testing.TB, serverAddr string, httpLocalAddr, httpAddr, tunnels := map[string]*proto.Tunnel{ proto.HTTP: { - Protocol: proto.HTTP, - Host: "localhost", - Auth: "user:password", + Protocol: proto.HTTP, + Host: "localhost", + Auth: "user:password", + LocalAddr: httpLocalAddr.String(), }, proto.TCP: { - Protocol: proto.TCP, - Addr: tcpLocalAddr.String(), + Protocol: proto.TCP, + LocalAddr: tcpLocalAddr.String(), }, } @@ -157,7 +159,7 @@ func TestIntegration(t *testing.T) { defer tcp.Close() // server - s := makeTunnelServer(t) + s := makeTunnelServer(t, nil) defer s.Stop() h := httptest.NewServer(s) defer h.Close() @@ -165,12 +167,12 @@ func TestIntegration(t *testing.T) { httpLocalAddr := h.Listener.Addr() tcpLocalAddr := freeAddr() - // client + // controller c := makeTunnelClient(t, s.Addr(), httpLocalAddr, http.Addr(), tcpLocalAddr, tcp.Addr(), ) - // FIXME: replace sleep with client state change watch when ready + // FIXME: replace sleep with controller state change watch when ready time.Sleep(500 * time.Millisecond) defer c.Stop() @@ -196,7 +198,7 @@ func TestIntegration(t *testing.T) { }() wg.Add(1) go func() { - testTCP(t, tcpLocalAddr, p, r) + testTCP(t, tcpLocalAddr, p, r, false) wg.Done() }() } @@ -204,6 +206,94 @@ func TestIntegration(t *testing.T) { wg.Wait() } +func makeTunnelRegisteredClient(t testing.TB, serverAddr string, tcpLocalAddr net.Addr) *tunnel.Client { + tcpProxy := tunnel.NewTCPProxy(tcpLocalAddr.String(), log.NewStdLogger()) + + tunnels := map[string]*proto.Tunnel{ + proto.TCP: { + Protocol: proto.TCP, + LocalAddr: tcpLocalAddr.String(), + }, + } + + c, err := tunnel.NewClient(&tunnel.ClientConfig{ + ServerAddr: serverAddr, + TLSClientConfig: tlsConfig(), + Tunnels: tunnels, + Registered: true, + Proxy: tunnel.Proxy(tunnel.ProxyFuncs{ + TCP: tcpProxy.Proxy, + }), + Logger: log.NewStdLogger(), + }) + if err != nil { + t.Fatal(err) + } + go func() { + if err := c.Start(); err != nil { + t.Log(err) + } + }() + + return c +} + +func TestRegisteredClientIntegration(t *testing.T) { + // local services + tcp, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + go echoTCP(tcp) + defer tcp.Close() + + tcpRemoteAddr := freeAddr() + numConnections := 3 + + // server + s := makeTunnelServer(t, ®isteredClientsProvider{&tunnel.RegisteredClientConfig{ + Connections: numConnections, + Tunnels: map[string]*proto.Tunnel{ + proto.TCP: { + Protocol: proto.TCP, + RemoteAddr: tcpRemoteAddr.String(), + }, + }, + }}) + defer s.Stop() + + // controller + c := makeTunnelRegisteredClient(t, s.Addr(), tcp.Addr()) + // FIXME: replace sleep with controller state change watch when ready + time.Sleep(500 * time.Millisecond) + defer c.Stop() + + payload := randPayload(payloadInitialSize, payloadLen) + table := []struct { + S []uint + }{ + {[]uint{200, 160, 120, 80, 40, 20}}, + {[]uint{40, 80, 120, 160, 200}}, + {[]uint{0, 0, 0, 0, 0, 200}}, + } + + var wg sync.WaitGroup + for _, test := range table { + for i, repeat := range test.S { + p := payload[i] + r := repeat + for i := 0; i <= numConnections; i++ { + wg.Add(1) + go func() { + testTCP(t, tcpRemoteAddr, p, r, true) + wg.Done() + }() + } + } + } + wg.Wait() +} + func testHTTP(t testing.TB, addr net.Addr, payload []byte, repeat uint) { url := fmt.Sprintf("http://localhost:%s/some/path", port(addr)) @@ -233,13 +323,17 @@ func testHTTP(t testing.TB, addr net.Addr, payload []byte, repeat uint) { } } -func testTCP(t testing.TB, addr net.Addr, payload []byte, repeat uint) { +func testTCP(t testing.TB, addr net.Addr, payload []byte, repeat uint, sleep bool) { conn, err := net.Dial("tcp", addr.String()) if err != nil { t.Fatal("Dial failed", err) } defer conn.Close() + if sleep { + time.Sleep(50 * time.Millisecond) + } + var buf = make([]byte, 10*1024*1024) var read, write int for repeat > 0 { diff --git a/pool.go b/pool.go index 2219b3f..ce50fb7 100644 --- a/pool.go +++ b/pool.go @@ -7,8 +7,10 @@ package tunnel import ( "context" "fmt" + "io" "net" "net/http" + "sort" "sync" "time" @@ -17,14 +19,108 @@ import ( "github.com/mmatczuk/go-http-tunnel/id" ) +const connKey = "conn" + +type clientConnections struct { + first *clientConnection + last *clientConnection + count int + mu sync.Mutex + lastid int + conns []*clientConnection +} + +func (cs *clientConnections) low() *clientConnection { + cs.mu.Lock() + defer cs.mu.Unlock() + sort.Slice(cs.conns, func(i, j int) bool { + return cs.conns[i].count < cs.conns[j].count + }) + return cs.conns[0].increase() +} + +func (cs *clientConnections) add(con *http2.ClientConn) *clientConnection { + cs.mu.Lock() + defer cs.mu.Unlock() + + c := &clientConnection{con, cs.last, nil, cs.lastid, 0} + cs.lastid++ + if cs.last == nil { + cs.first = c + } else { + c.prev.next = c + } + cs.last = c + cs.count++ + cs.conns = append(cs.conns, c) + return cs.last +} + +func (cs *clientConnections) remove(c *clientConnection) { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.count == 1 { + cs.first = nil + cs.last = nil + } else if cs.last == c { + cs.last = c.prev + cs.last.next = nil + } else if cs.first == c { + cs.first = c.next + cs.first.prev = nil + } else { + p, n := c.prev, c.next + p.next = n + n.prev = p + } + + c.prev = nil + c.next = nil + cs.count-- +} + +type clientConnectionSetter struct { + conn *clientConnection +} + +type clientConnection struct { + *http2.ClientConn + prev *clientConnection + next *clientConnection + id int + count int +} + +func (c *clientConnection) decrease() *clientConnection { + c.count-- + return c +} + +func (c *clientConnection) increase() *clientConnection { + c.count++ + return c +} + type connPair struct { - conn net.Conn - clientConn *http2.ClientConn + controller *clientConnectionController + clientConn *clientConnection + conns *clientConnections + current *clientConnection + mu sync.Mutex +} + +func (cp *connPair) next() *clientConnection { + cp.mu.Lock() + defer cp.mu.Unlock() + con := cp.conns.low() + cp.current = con + return con } type connPool struct { t *http2.Transport - conns map[string]connPair // key is host:port + conns map[string]*connPair // key is host:port free func(identifier id.ID) mu sync.RWMutex } @@ -33,10 +129,17 @@ func newConnPool(t *http2.Transport, f func(identifier id.ID)) *connPool { return &connPool{ t: t, free: f, - conns: make(map[string]connPair), + conns: make(map[string]*connPair), } } - +func (p *connPool) newRequest(method string, identifier id.ID, body io.Reader) (*http.Request, error) { + req, err := http.NewRequest(method, p.URL(identifier), body) + if err != nil { + return nil, err + } + req = req.WithContext(context.WithValue(req.Context(), connKey, &clientConnectionSetter{})) + return req, err +} func (p *connPool) URL(identifier id.ID) string { return fmt.Sprint("https://", identifier) } @@ -45,8 +148,13 @@ func (p *connPool) GetClientConn(req *http.Request, addr string) (*http2.ClientC p.mu.RLock() defer p.mu.RUnlock() - if cp, ok := p.conns[addr]; ok && cp.clientConn.CanTakeNewRequest() { - return cp.clientConn, nil + if cp, ok := p.conns[addr]; ok { + conn := cp.next() + if setter := req.Context().Value(connKey); setter != nil { + setter.(*clientConnectionSetter).conn = conn + } + cp.controller.logger.Log("level", 3, "client_conn", fmt.Sprintf("#%d", conn.id), "addr", addr) + return conn.ClientConn, nil } return nil, errClientNotConnected @@ -57,18 +165,44 @@ func (p *connPool) MarkDead(c *http2.ClientConn) { defer p.mu.Unlock() for addr, cp := range p.conns { - if cp.clientConn == c { + if cp.clientConn.ClientConn == c { p.close(cp, addr) return } } } -func (p *connPool) AddConn(conn net.Conn, identifier id.ID) error { +func (p *connPool) Has(ID id.ID) (ok bool) { p.mu.Lock() defer p.mu.Unlock() - addr := p.addr(identifier) + addr := p.addr(ID) + _, ok = p.conns[addr] + return +} + +func (p *connPool) AddClientConnection(ID id.ID, conn net.Conn) (ccon *clientConnection, err error) { + c, err := p.t.NewClientConn(conn) + if err != nil { + return nil, err + } + + addr := p.addr(ID) + + cp := p.conns[addr] + if (cp.conns.count - 1) >= cp.controller.cfg.Connections { + return nil, errClientManyConnections + } + p.mu.Lock() + defer p.mu.Unlock() + return cp.conns.add(c), nil +} + +func (p *connPool) AddConn(controller *clientConnectionController) error { + p.mu.Lock() + defer p.mu.Unlock() + + addr := p.addr(controller.ID) if cp, ok := p.conns[addr]; ok { if err := p.ping(cp); err != nil { @@ -78,14 +212,13 @@ func (p *connPool) AddConn(conn net.Conn, identifier id.ID) error { } } - c, err := p.t.NewClientConn(conn) + c, err := p.t.NewClientConn(controller.conn) if err != nil { return err } - p.conns[addr] = connPair{ - conn: conn, - clientConn: c, - } + cp := &connPair{controller: controller, conns: &clientConnections{}} + cp.clientConn = cp.conns.add(c) + p.conns[addr] = cp return nil } @@ -101,6 +234,12 @@ func (p *connPool) DeleteConn(identifier id.ID) { } } +func (p *connPool) closeReqConn(req *http.Request) { + if conn := req.Context().Value(connKey); conn != nil { + conn.(*clientConnectionSetter).conn.decrease() + } +} + func (p *connPool) Ping(identifier id.ID) (time.Duration, error) { p.mu.Lock() defer p.mu.Unlock() @@ -116,15 +255,15 @@ func (p *connPool) Ping(identifier id.ID) (time.Duration, error) { return 0, errClientNotConnected } -func (p *connPool) ping(cp connPair) error { +func (p *connPool) ping(cp *connPair) error { ctx, cancel := context.WithTimeout(context.Background(), DefaultPingTimeout) defer cancel() return cp.clientConn.Ping(ctx) } -func (p *connPool) close(cp connPair, addr string) { - cp.conn.Close() +func (p *connPool) close(cp *connPair, addr string) { + cp.controller.conn.Close() delete(p.conns, addr) if p.free != nil { p.free(p.identifier(addr)) diff --git a/proto/controlmsg.go b/proto/controlmsg.go index 2dddc93..aaef5f1 100644 --- a/proto/controlmsg.go +++ b/proto/controlmsg.go @@ -13,6 +13,7 @@ import ( const ( HeaderError = "X-Error" + HeaderLocalAddr = "X-Local-RemoteAddr" HeaderAction = "X-Action" HeaderForwardedHost = "X-Forwarded-Host" HeaderForwardedProto = "X-Forwarded-Proto" @@ -38,6 +39,7 @@ const ( // used to inform client about the data and action to take. Based on that client // routes requests to backend services. type ControlMessage struct { + LocalAddr string Action string ForwardedHost string ForwardedProto string @@ -47,6 +49,7 @@ type ControlMessage struct { // ReadControlMessage reads ControlMessage from HTTP headers. func ReadControlMessage(r *http.Request) (*ControlMessage, error) { msg := ControlMessage{ + LocalAddr: r.Header.Get(HeaderLocalAddr), Action: r.Header.Get(HeaderAction), ForwardedHost: r.Header.Get(HeaderForwardedHost), ForwardedProto: r.Header.Get(HeaderForwardedProto), @@ -55,6 +58,9 @@ func ReadControlMessage(r *http.Request) (*ControlMessage, error) { var missing []string + if msg.LocalAddr == "" { + missing = append(missing, HeaderLocalAddr) + } if msg.Action == "" { missing = append(missing, HeaderAction) } @@ -74,6 +80,7 @@ func ReadControlMessage(r *http.Request) (*ControlMessage, error) { // WriteToHeader writes ControlMessage to HTTP header. func (c *ControlMessage) WriteToHeader(h http.Header) { + h.Set(HeaderLocalAddr, c.LocalAddr) h.Set(HeaderAction, string(c.Action)) h.Set(HeaderForwardedHost, c.ForwardedHost) h.Set(HeaderForwardedProto, c.ForwardedProto) diff --git a/proto/controlmsg_test.go b/proto/controlmsg_test.go index 51f39e0..564ca39 100644 --- a/proto/controlmsg_test.go +++ b/proto/controlmsg_test.go @@ -20,32 +20,41 @@ func TestControlMessageWriteRead(t *testing.T) { }{ { &ControlMessage{ + LocalAddr: "local_addr", Action: "action", ForwardedHost: "forwarded_host", ForwardedProto: "forwarded_proto", }, nil, }, + { + &ControlMessage{ + Action: "action", + ForwardedHost: "forwarded_host", + ForwardedProto: "forwarded_proto", + }, + errors.New("missing headers: [X-Local-RemoteAddr]"), + }, { &ControlMessage{ ForwardedHost: "forwarded_host", ForwardedProto: "forwarded_proto", }, - errors.New("missing headers: [X-Action]"), + errors.New("missing headers: [X-Local-RemoteAddr X-Action]"), }, { &ControlMessage{ Action: "action", ForwardedHost: "forwarded_host", }, - errors.New("missing headers: [X-Forwarded-Proto]"), + errors.New("missing headers: [X-Local-RemoteAddr X-Forwarded-Proto]"), }, { &ControlMessage{ Action: "action", ForwardedProto: "forwarded_proto", }, - errors.New("missing headers: [X-Forwarded-Host]"), + errors.New("missing headers: [X-Local-RemoteAddr X-Forwarded-Host]"), }, } diff --git a/proto/tunnel.go b/proto/tunnel.go index 2b5e71e..2d4c96a 100644 --- a/proto/tunnel.go +++ b/proto/tunnel.go @@ -6,18 +6,35 @@ package proto // Tunnel describes a single tunnel between client and server. When connecting // client sends tunnels to server. If client gets connected server proxies -// connections to given Host and Addr to the client. +// connections to given Host and RemoteAddr to the client. type Tunnel struct { + // Disabled tunnel disabled + Disabled bool // Protocol specifies tunnel protocol, must be one of protocols known // by the server. - Protocol string + Protocol string `yaml:"proto"` // Host specified HTTP request host, it's required for HTTP and WS // tunnels. Host string // Auth specifies HTTP basic auth credentials in form "user:password", // if set server would protect HTTP and WS tunnels with basic auth. Auth string - // Addr specifies TCP address server would listen on, it's required + // RemoteAddr specifies TCP address server would listen on, it's required // for TCP tunnels. - Addr string + RemoteAddr string `yaml:"remote_addr"` + // LocalAddr client local addr + LocalAddr string `yaml:"local_addr"` +} + +// RegisteredTunnel describes a single registered tunnel between client and server. When connecting +// registered client sends tunnels to server. If client gets connected server proxies +// connections to given Host and LocalAddr to the client. +type RegisteredTunnel struct { + // Disabled tunnel disabled + Disabled bool + // Auth specifies HTTP basic auth credentials in form "user:password", + // if set server would protect HTTP and WS tunnels with basic auth. + Auth string + // LocalAddr client local addr + LocalAddr string `yaml:"local_addr"` } diff --git a/registered_client.go b/registered_client.go new file mode 100644 index 0000000..f28cfd4 --- /dev/null +++ b/registered_client.go @@ -0,0 +1,86 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + +package tunnel + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + + "github.com/mmatczuk/go-http-tunnel/id" + "github.com/mmatczuk/go-http-tunnel/proto" + "gopkg.in/yaml.v2" +) + +type RegisteredClientInfo struct { + Hostname string + Tunnels map[string]*proto.RegisteredTunnel +} + +type ErrClientNotRegistered struct { + ClientID id.ID +} + +func (e ErrClientNotRegistered) Error() string { + return fmt.Sprintf("Client %q not registered", e.ClientID.String()) +} + +func IsNotRegistered(err error) (ok bool) { + if err == nil { + return + } + _, ok = err.(ErrClientNotRegistered) + return +} + +type RegisteredClientConfig struct { + Disabled bool + ID id.ID + Name string + Description string + Hostname string + Tunnels map[string]*proto.Tunnel + Connections int +} + +type RegisteredClientsProvider interface { + Get(clientID id.ID) (client *RegisteredClientConfig, err error) +} + +type RegisteredClientsFileSystemProvider struct { + StorageDir string +} + +func (p *RegisteredClientsFileSystemProvider) Get(clientID id.ID) (client *RegisteredClientConfig, err error) { + base := filepath.Join(p.StorageDir, clientID.String()) + pth := filepath.Join(base, "config.yaml") + var f io.ReadCloser + if f, err = os.Open(pth); err != nil { + if os.IsNotExist(err) { + return nil, &ErrClientNotRegistered{clientID} + } + return nil, err + } + defer f.Close() + var data []byte + if data, err = ioutil.ReadAll(f); err != nil { + return nil, err + } + client = &RegisteredClientConfig{} + if err = yaml.Unmarshal(data, client); err != nil { + return nil, err + } + client.ID = clientID + if client.Tunnels == nil { + client.Tunnels = map[string]*proto.Tunnel{} + } + + if _, err := os.Stat(filepath.Join(base, "disabled")); err == nil { + client.Disabled = true + } + return +} diff --git a/registered_clients_test.go b/registered_clients_test.go new file mode 100644 index 0000000..7ca95dc --- /dev/null +++ b/registered_clients_test.go @@ -0,0 +1,18 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + +package tunnel_test + +import ( + "github.com/mmatczuk/go-http-tunnel" + "github.com/mmatczuk/go-http-tunnel/id" +) + +type registeredClientsProvider struct { + cfg *tunnel.RegisteredClientConfig +} + +func (p registeredClientsProvider) Get(clientID id.ID) (client *tunnel.RegisteredClientConfig, err error) { + return p.cfg, nil +} diff --git a/registry.go b/registry.go index 98fead8..ef03246 100644 --- a/registry.go +++ b/registry.go @@ -9,24 +9,32 @@ import ( "net" "sync" + "github.com/mmatczuk/go-http-tunnel/proto" + "github.com/mmatczuk/go-http-tunnel/id" "github.com/mmatczuk/go-http-tunnel/log" ) +type Listener struct { + net.Listener + Tunnel *proto.Tunnel +} + // RegistryItem holds information about hosts and listeners associated with a -// client. +// controller. type RegistryItem struct { Hosts []*HostAuth - Listeners []net.Listener + Listeners []*Listener } // HostAuth holds host and authentication info. type HostAuth struct { - Host string - Auth *Auth + Tunnel *proto.Tunnel + Auth *Auth } type hostInfo struct { + tunnel *proto.Tunnel identifier id.ID auth *Auth } @@ -52,7 +60,7 @@ func newRegistry(logger log.Logger) *registry { var voidRegistryItem = &RegistryItem{} -// Subscribe allows to connect client with a given identifier. +// Subscribe allows to connect controller with a given identifier. func (r *registry) Subscribe(identifier id.ID) { r.mu.Lock() defer r.mu.Unlock() @@ -70,7 +78,7 @@ func (r *registry) Subscribe(identifier id.ID) { r.items[identifier] = voidRegistryItem } -// IsSubscribed returns true if client is subscribed. +// IsSubscribed returns true if controller is subscribed. func (r *registry) IsSubscribed(identifier id.ID) bool { r.mu.RLock() defer r.mu.RUnlock() @@ -78,8 +86,8 @@ func (r *registry) IsSubscribed(identifier id.ID) bool { return ok } -// Subscriber returns client identifier assigned to given host. -func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) { +// Subscriber returns controller identifier assigned to given host. +func (r *registry) Subscriber(hostPort string) (id.ID, *hostInfo, bool) { r.mu.RLock() defer r.mu.RUnlock() @@ -88,10 +96,10 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) { return id.ID{}, nil, false } - return h.identifier, h.auth, ok + return h.identifier, h, ok } -// Unsubscribe removes client from registry and returns it's RegistryItem. +// Unsubscribe removes controller from registry and returns it's RegistryItem. func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem { r.mu.Lock() defer r.mu.Unlock() @@ -109,7 +117,7 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem { if i.Hosts != nil { for _, h := range i.Hosts { - delete(r.hosts, h.Host) + delete(r.hosts, h.Tunnel.Host) } } @@ -141,13 +149,14 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error { if h.Auth != nil && h.Auth.User == "" { return fmt.Errorf("missing auth user") } - if _, ok := r.hosts[trimPort(h.Host)]; ok { - return fmt.Errorf("host %q is occupied", h.Host) + if _, ok := r.hosts[trimPort(h.Tunnel.Host)]; ok { + return fmt.Errorf("host %q is occupied", h.Tunnel.Host) } } for _, h := range i.Hosts { - r.hosts[trimPort(h.Host)] = &hostInfo{ + r.hosts[trimPort(h.Tunnel.Host)] = &hostInfo{ + tunnel: h.Tunnel, identifier: identifier, auth: h.Auth, } @@ -176,7 +185,7 @@ func (r *registry) clear(identifier id.ID) *RegistryItem { if i.Hosts != nil { for _, h := range i.Hosts { - delete(r.hosts, trimPort(h.Host)) + delete(r.hosts, trimPort(h.Tunnel.Host)) } } diff --git a/server.go b/server.go index b3f7bd6..ffbc2a2 100644 --- a/server.go +++ b/server.go @@ -13,6 +13,7 @@ import ( "io" "net" "net/http" + "strconv" "strings" "time" @@ -25,7 +26,7 @@ import ( // ServerConfig defines configuration for the Server. type ServerConfig struct { - // Addr is TCP address to listen for client connections. If empty ":0" + // RemoteAddr is TCP address to listen for controller connections. If empty ":0" // is used. Addr string // AutoSubscribe if enabled will automatically subscribe new clients on @@ -33,14 +34,16 @@ type ServerConfig struct { AutoSubscribe bool // TLSConfig specifies the tls configuration to use with tls.Listener. TLSConfig *tls.Config - // Listener specifies optional listener for client connections. If nil - // tls.Listen("tcp", Addr, TLSConfig) is used. + // Listener specifies optional listener for controller connections. If nil + // tls.Listen("tcp", RemoteAddr, TLSConfig) is used. Listener net.Listener // Logger is optional logger. If nil logging is disabled. Logger log.Logger + // RegisteredClient get + RegisteredClientsProvider RegisteredClientsProvider } -// Server is responsible for proxying public connections to the client over a +// Server is responsible for proxying public connections to the controller over a // tunnel connection. type Server struct { *registry @@ -91,7 +94,7 @@ func listener(config *ServerConfig) (net.Listener, error) { } if config.Addr == "" { - return nil, errors.New("missing Addr") + return nil, errors.New("missing RemoteAddr") } if config.TLSConfig == nil { return nil, errors.New("missing TLSConfig") @@ -100,8 +103,8 @@ func listener(config *ServerConfig) (net.Listener, error) { return net.Listen("tcp", config.Addr) } -// disconnected clears resources used by client, it's invoked by connection pool -// when client goes away. +// disconnected clears resources used by controller, it's invoked by connection pool +// when controller goes away. func (s *Server) disconnected(identifier id.ID) { s.logger.Log( "level", 1, @@ -165,11 +168,11 @@ func (s *Server) Start() { ) } - go s.handleClient(tls.Server(conn, s.config.TLSConfig)) + go s.handleClient(tls.Server(conn, s.config.TLSConfig), true) } } -func (s *Server) handleClient(conn net.Conn) { +func (s *Server) handleClient(conn net.Conn, main bool) { logger := log.NewContext(s.logger).With("addr", conn.RemoteAddr()) logger.Log( @@ -178,46 +181,84 @@ func (s *Server) handleClient(conn net.Conn) { ) var ( - identifier id.ID - req *http.Request - resp *http.Response - tunnels map[string]*proto.Tunnel err error - ok bool - inConnPool bool + ok bool + ID id.ID + cfg *RegisteredClientConfig + clientInfo *RegisteredClientInfo + controller *clientConnectionController + body io.Reader + req *http.Request + resp *http.Response + tlsConn *tls.Conn + ccon *clientConnection ) - tlsConn, ok := conn.(*tls.Conn) - if !ok { + if tlsConn, ok = conn.(*tls.Conn); !ok { logger.Log( "level", 0, "msg", "invalid connection type", - "err", fmt.Errorf("expected TLS conn, got %T", conn), + "err", fmt.Errorf("expected TLS controller, got %T", conn), ) - goto reject + goto Reject } - identifier, err = id.PeerID(tlsConn) + ID, err = id.PeerID(tlsConn) if err != nil { logger.Log( "level", 2, "msg", "certificate error", "err", err, ) - goto reject + goto Reject } - logger = logger.With("identifier", identifier) + logger = logger.With("identifier", ID) if s.config.AutoSubscribe { - s.Subscribe(identifier) - } else if !s.IsSubscribed(identifier) { + s.Subscribe(ID) + } else if s.config.RegisteredClientsProvider != nil { + if cfg, err = s.config.RegisteredClientsProvider.Get(ID); err == nil { + if cfg.Disabled { + logger.Log( + "level", 2, + "msg", "Client has be disabled", + ) + goto Reject + } + if s.connPool.Has(ID) { + if ccon, err = s.connPool.AddClientConnection(ID, conn); err != nil { + goto Reject + } + logger.Log( + "level", 1, + "msg", fmt.Sprintf("new connection #%d for client added", ccon.id), + ) + return + } else { + s.Subscribe(ID) + } + } else if IsNotRegistered(err) { + logger.Log( + "level", 2, + "msg", err.Error(), + ) + goto Reject + } else { + logger.Log( + "level", 2, + "msg", "Get registered controller failed", + "err", err, + ) + goto Reject + } + } else if !s.IsSubscribed(ID) { logger.Log( "level", 2, - "msg", "unknown client", + "msg", "unknown controller", ) - goto reject + goto Reject } if err = conn.SetDeadline(time.Time{}); err != nil { @@ -226,43 +267,52 @@ func (s *Server) handleClient(conn net.Conn) { "msg", "setting infinite deadline failed", "err", err, ) - goto reject + goto Reject } - if err := s.connPool.AddConn(conn, identifier); err != nil { + controller = &clientConnectionController{cfg, conn, ID, logger} + + if err := s.connPool.AddConn(controller); err != nil { logger.Log( "level", 2, "msg", "adding connection failed", "err", err, ) - goto reject + goto Reject } + inConnPool = true - req, err = http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil) + req, err = s.connPool.newRequest(http.MethodConnect, controller.ID, body) if err != nil { logger.Log( "level", 2, "msg", "handshake request creation failed", "err", err, ) - goto reject + goto Reject + } + + if cfg != nil && cfg.Connections > 0 { + req.Header.Set(HeaderConnectionsCount, strconv.Itoa(int(controller.cfg.Connections))) } { - ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout) + ctx, cancel := context.WithTimeout(req.Context(), DefaultTimeout) defer cancel() req = req.WithContext(ctx) } resp, err = s.httpClient.Do(req) + defer s.connPool.closeReqConn(req) + if err != nil { logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } if resp.StatusCode != http.StatusOK { @@ -272,7 +322,7 @@ func (s *Server) handleClient(conn net.Conn) { "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } if resp.ContentLength == 0 { @@ -282,35 +332,68 @@ func (s *Server) handleClient(conn net.Conn) { "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } - if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(&tunnels); err != nil { - logger.Log( - "level", 2, - "msg", "handshake failed", - "err", err, - ) - goto reject + if cfg == nil { + cfg = &RegisteredClientConfig{} + if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(cfg); err != nil { + logger.Log( + "level", 2, + "msg", "handshake failed", + "err", err, + ) + goto Reject + } + // Ony main connection + cfg.Connections = 0 + } else { + clientInfo = &RegisteredClientInfo{} + if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(clientInfo); err != nil { + logger.Log( + "level", 2, + "msg", "handshake failed", + "err", err, + ) + goto Reject + } + + cfg.Hostname = clientInfo.Hostname + for name, ct := range clientInfo.Tunnels { + if t, ok := cfg.Tunnels[name]; ok { + if ct.Disabled { + t.Disabled = true + } else { + t.LocalAddr = ct.LocalAddr + t.Auth = ct.Auth + } + } else { + logger.Log( + "level", 1, + "msg", fmt.Sprintf("tunnel %q not registered", name), + ) + delete(clientInfo.Tunnels, name) + } + } } - if len(tunnels) == 0 { + if len(cfg.Tunnels) == 0 { err = fmt.Errorf("No tunnels") logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } - if err = s.addTunnels(tunnels, identifier); err != nil { + if err = s.addTunnels(cfg.Tunnels, ID); err != nil { logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } logger.Log( @@ -319,32 +402,31 @@ func (s *Server) handleClient(conn net.Conn) { ) return - -reject: +Reject: logger.Log( "level", 1, "action", "rejected", ) if inConnPool { - s.notifyError(err, identifier) - s.connPool.DeleteConn(identifier) + s.notifyError(err, ID) + s.connPool.DeleteConn(ID) } conn.Close() } -// notifyError tries to send error to client. +// notifyError tries to send error to controller. func (s *Server) notifyError(serverError error, identifier id.ID) { if serverError == nil { return } - req, err := http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil) + req, err := s.connPool.newRequest(http.MethodConnect, identifier, nil) if err != nil { s.logger.Log( "level", 2, - "action", "client error notification failed", + "action", "controller error notification failed", "identifier", identifier, "err", err, ) @@ -364,17 +446,21 @@ func (s *Server) notifyError(serverError error, identifier id.ID) { func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error { i := &RegistryItem{ Hosts: []*HostAuth{}, - Listeners: []net.Listener{}, + Listeners: []*Listener{}, } var err error for name, t := range tunnels { + if t.Disabled { + continue + } + switch t.Protocol { case proto.HTTP: - i.Hosts = append(i.Hosts, &HostAuth{t.Host, NewAuth(t.Auth)}) + i.Hosts = append(i.Hosts, &HostAuth{t, NewAuth(t.Auth)}) case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX: var l net.Listener - l, err = net.Listen(t.Protocol, t.Addr) + l, err = net.Listen(t.Protocol, t.RemoteAddr) if err != nil { goto rollback } @@ -386,7 +472,7 @@ func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) "addr", l.Addr(), ) - i.Listeners = append(i.Listeners, l) + i.Listeners = append(i.Listeners, &Listener{l, t}) default: err = fmt.Errorf("unsupported protocol for tunnel %s: %s", name, t.Protocol) goto rollback @@ -412,7 +498,7 @@ rollback: return err } -// Unsubscribe removes client from registry, disconnects client if already +// Unsubscribe removes controller from registry, disconnects controller if already // connected and returns it's RegistryItem. func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem { s.connPool.DeleteConn(identifier) @@ -424,7 +510,7 @@ func (s *Server) Ping(identifier id.ID) (time.Duration, error) { return s.connPool.Ping(identifier) } -func (s *Server) listen(l net.Listener, identifier id.ID) { +func (s *Server) listen(l *Listener, identifier id.ID) { addr := l.Addr().String() for { @@ -451,6 +537,7 @@ func (s *Server) listen(l net.Listener, identifier id.ID) { } msg := &proto.ControlMessage{ + LocalAddr: l.Tunnel.LocalAddr, Action: proto.ActionProxy, ForwardedHost: l.Addr().String(), ForwardedProto: l.Addr().Network(), @@ -480,7 +567,7 @@ func (s *Server) listen(l net.Listener, identifier id.ID) { } } -// ServeHTTP proxies http connection to the client. +// ServeHTTP proxies http connection to the controller. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { resp, err := s.RoundTrip(r) if err == errUnauthorised { @@ -507,7 +594,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(resp.StatusCode) transfer(w, resp.Body, log.NewContext(s.logger).With( - "dir", "client to user", + "dir", "controller to user", "dst", r.RemoteAddr, "src", r.Host, )) @@ -515,7 +602,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // RoundTrip is http.RoundTriper implementation. func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { - identifier, auth, ok := s.Subscriber(r.Host) + identifier, hostInfo, ok := s.Subscriber(r.Host) if !ok { return nil, errClientNotSubscribed } @@ -526,9 +613,9 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { } outr.Header = cloneHeader(r.Header) - if auth != nil { + if hostInfo.auth != nil { user, password, _ := r.BasicAuth() - if auth.User != user || auth.Password != password { + if hostInfo.auth.User != user || hostInfo.auth.Password != password { return nil, errUnauthorised } outr.Header.Del("Authorization") @@ -550,6 +637,7 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { } msg := &proto.ControlMessage{ + LocalAddr: hostInfo.tunnel.LocalAddr, Action: proto.ActionProxy, ForwardedHost: r.Host, ForwardedProto: scheme, @@ -561,7 +649,7 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error { s.logger.Log( "level", 2, - "action", "proxy conn", + "action", "proxy controller", "identifier", identifier, "ctrlMsg", msg, ) @@ -576,14 +664,15 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe if err != nil { return err } + defer s.connPool.closeReqConn(req) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(req.Context()) req = req.WithContext(ctx) done := make(chan struct{}) go func() { transfer(pw, conn, log.NewContext(s.logger).With( - "dir", "user to client", + "dir", "user to controller", "dst", identifier, "src", conn.RemoteAddr(), )) @@ -595,10 +684,11 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe if err != nil { return fmt.Errorf("io error: %s", err) } + defer resp.Body.Close() transfer(conn, resp.Body, log.NewContext(s.logger).With( - "dir", "client to user", + "dir", "controller to user", "dst", conn.RemoteAddr(), "src", identifier, )) @@ -607,7 +697,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe s.logger.Log( "level", 2, - "action", "proxy conn done", + "action", "proxy controller done", "identifier", identifier, "ctrlMsg", msg, ) @@ -631,6 +721,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control if err != nil { return nil, fmt.Errorf("proxy request error: %s", err) } + defer s.connPool.closeReqConn(req) go func() { cw := &countWriter{pw, 0} @@ -650,7 +741,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control "action", "transferred", "identifier", identifier, "bytes", cw.count, - "dir", "user to client", + "dir", "user to controller", "dst", r.Host, "src", r.RemoteAddr, ) @@ -676,11 +767,11 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control return resp, nil } -// connectRequest creates HTTP request to client with a given identifier having +// connectRequest creates HTTP request to controller with a given identifier having // control message and data input stream, output data stream results from // response the created request. func (s *Server) connectRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) { - req, err := http.NewRequest(http.MethodPut, s.connPool.URL(identifier), r) + req, err := s.connPool.newRequest(http.MethodPut, identifier, r) if err != nil { return nil, fmt.Errorf("could not create request: %s", err) } diff --git a/tcpproxy.go b/tcpproxy.go index 613de9b..1820aee 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -68,7 +68,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage return } - target := p.localAddrFor(msg.ForwardedHost) + target := p.localAddrFor(msg.LocalAddr) if target == "" { p.logger.Log( "level", 1,