From 40a08660cb9533b65dc85ba491cc0eb19d988d90 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Thu, 11 Oct 2018 11:21:41 -0300 Subject: [PATCH 01/12] Add support for client registration --- {cmd => cli}/tunnel/config.go | 31 ++-- {cmd => cli}/tunnel/normalize.go | 2 +- {cmd => cli}/tunnel/normalize_test.go | 2 +- {cmd => cli}/tunnel/options.go | 44 ++++-- {cmd => cli}/tunnel/tunnel.go | 23 +-- {cmd => cli}/tunnel/version.go | 2 +- {cmd => cli}/tunneld/banner.go | 2 +- {cmd => cli}/tunneld/options.go | 9 +- {cmd => cli}/tunneld/tunneld.go | 40 +++-- {cmd => cli}/tunneld/version.go | 2 +- client.go | 148 ++++++++++++++---- cmd/tunnel/main.go | 7 + cmd/tunneld/main.go | 7 + common.go | 17 ++ doc.go | 2 +- errors.go | 7 +- integration_test.go | 7 +- pool.go | 130 ++++++++++++++-- proto/controlmsg.go | 7 + proto/controlmsg_test.go | 15 +- proto/tunnel.go | 23 ++- registered_client.go | 76 +++++++++ registry.go | 39 +++-- server.go | 214 ++++++++++++++++++-------- tcpproxy.go | 2 +- 25 files changed, 651 insertions(+), 207 deletions(-) rename {cmd => cli}/tunnel/config.go (80%) rename {cmd => cli}/tunnel/normalize.go (98%) rename {cmd => cli}/tunnel/normalize_test.go (99%) rename {cmd => cli}/tunnel/options.go (73%) rename {cmd => cli}/tunnel/tunnel.go (92%) rename {cmd => cli}/tunnel/version.go (91%) rename {cmd => cli}/tunneld/banner.go (97%) rename {cmd => cli}/tunneld/options.go (90%) rename {cmd => cli}/tunneld/tunneld.go (76%) rename {cmd => cli}/tunneld/version.go (91%) create mode 100644 cmd/tunnel/main.go create mode 100644 cmd/tunneld/main.go create mode 100644 common.go create mode 100644 registered_client.go diff --git a/cmd/tunnel/config.go b/cli/tunnel/config.go similarity index 80% rename from cmd/tunnel/config.go rename to cli/tunnel/config.go index da20ffd..67f3fb0 100644 --- a/cmd/tunnel/config.go +++ b/cli/tunnel/config.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "fmt" @@ -33,8 +33,8 @@ type BackoffConfig struct { // Tunnel defines a tunnel. type Tunnel struct { - Protocol string `yaml:"proto,omitempty"` - Addr string `yaml:"addr,omitempty"` + Protocol string `yaml:"protocol,omitempty"` + LocalAddr string `yaml:"local_addr,omitempty"` Auth string `yaml:"auth,omitempty"` Host string `yaml:"host,omitempty"` RemoteAddr string `yaml:"remote_addr,omitempty"` @@ -42,6 +42,7 @@ type Tunnel struct { // ClientConfig is a tunnel client configuration. type ClientConfig struct { + Registered bool `yaml:"registered"` ServerAddr string `yaml:"server_addr"` TLSCrt string `yaml:"tls_crt"` TLSKey string `yaml:"tls_key"` @@ -81,11 +82,11 @@ func loadClientConfigFromFile(file string) (*ClientConfig, error) { for name, t := range c.Tunnels { switch t.Protocol { case proto.HTTP: - if err := validateHTTP(t); err != nil { + if err := validateHTTP(c.Registered, t); err != nil { return nil, fmt.Errorf("%s %s", name, err) } case proto.TCP, proto.TCP4, proto.TCP6: - if err := validateTCP(t); err != nil { + if err := validateTCP(c.Registered, t); err != nil { return nil, fmt.Errorf("%s %s", name, err) } default: @@ -96,15 +97,15 @@ func loadClientConfigFromFile(file string) (*ClientConfig, error) { return &c, nil } -func validateHTTP(t *Tunnel) error { +func validateHTTP(registered bool, t *Tunnel) error { var err error - if t.Host == "" { + if !registered && t.Host == "" { return fmt.Errorf("host: missing") } - if t.Addr == "" { + if t.LocalAddr == "" { return fmt.Errorf("addr: missing") } - if t.Addr, err = normalizeURL(t.Addr); err != nil { + if t.LocalAddr, err = normalizeURL(t.LocalAddr); err != nil { return fmt.Errorf("addr: %s", err) } @@ -117,15 +118,17 @@ func validateHTTP(t *Tunnel) error { return nil } -func validateTCP(t *Tunnel) error { +func validateTCP(registered bool, t *Tunnel) error { var err error - if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil { - return fmt.Errorf("remote_addr: %s", err) + if !registered { + if t.RemoteAddr, err = normalizeAddress(t.RemoteAddr); err != nil { + return fmt.Errorf("remote_addr: %s", err) + } } - if t.Addr == "" { + if t.LocalAddr == "" { return fmt.Errorf("addr: missing") } - if t.Addr, err = normalizeAddress(t.Addr); err != nil { + if t.LocalAddr, err = normalizeAddress(t.LocalAddr); err != nil { return fmt.Errorf("addr: %s", err) } diff --git a/cmd/tunnel/normalize.go b/cli/tunnel/normalize.go similarity index 98% rename from cmd/tunnel/normalize.go rename to cli/tunnel/normalize.go index 2f9f84d..d205b6d 100644 --- a/cmd/tunnel/normalize.go +++ b/cli/tunnel/normalize.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "fmt" diff --git a/cmd/tunnel/normalize_test.go b/cli/tunnel/normalize_test.go similarity index 99% rename from cmd/tunnel/normalize_test.go rename to cli/tunnel/normalize_test.go index 5db7eb1..ef65c2d 100644 --- a/cmd/tunnel/normalize_test.go +++ b/cli/tunnel/normalize_test.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "strings" diff --git a/cmd/tunnel/options.go b/cli/tunnel/options.go similarity index 73% rename from cmd/tunnel/options.go rename to cli/tunnel/options.go index ef3176f..b70290d 100644 --- a/cmd/tunnel/options.go +++ b/cli/tunnel/options.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "flag" @@ -26,19 +26,31 @@ Examples: tunnel -config config.yaml -log-level 2 start ssh tunnel start-all -config.yaml: +Normal Client config: +client.yaml: server_addr: SERVER_IP:5223 tunnels: webui: proto: http - addr: localhost:8080 + local_addr: localhost:8080 auth: user:password host: webui.my-tunnel-host.com ssh: proto: tcp - addr: 192.168.0.5:22 + local_addr: 192.168.0.5:22 remote_addr: 0.0.0.0:22 +Registered Client config: +registered_client.yaml: + registered: true + server_addr: SERVER_IP:5223 + tunnels: + webui: + local_addr: localhost:8080 + auth: user:password + ssh: + local_addr: 192.168.0.5:22 + Author: Written by M. Matczuk (mmatczuk@gmail.com) @@ -55,24 +67,24 @@ func init() { } type options struct { - config string - logLevel int - version bool - command string - args []string + config string + logLevel int + version bool + command string + args []string } -func parseArgs() (*options, error) { - config := flag.String("config", "tunnel.yml", "Path to tunnel configuration file") +func parseArgs(args ...string) (*options, error) { + config := flag.String("config", "tunnel.yaml", "Path to tunnel configuration file") logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") version := flag.Bool("version", false, "Prints tunnel version") - flag.Parse() + flag.CommandLine.Parse(args) opts := &options{ - config: *config, - logLevel: *logLevel, - version: *version, - command: flag.Arg(0), + config: *config, + logLevel: *logLevel, + version: *version, + command: flag.Arg(0), } if opts.version { diff --git a/cmd/tunnel/tunnel.go b/cli/tunnel/tunnel.go similarity index 92% rename from cmd/tunnel/tunnel.go rename to cli/tunnel/tunnel.go index 9bce9ea..9ebbbc3 100644 --- a/cmd/tunnel/tunnel.go +++ b/cli/tunnel/tunnel.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel import ( "crypto/tls" @@ -23,8 +23,12 @@ import ( "github.com/mmatczuk/go-http-tunnel/proto" ) -func main() { - opts, err := parseArgs() +func Main() { + MainArgs(os.Args[1:]...) +} + +func MainArgs(args ...string) { + opts, err := parseArgs(args...) if err != nil { fatal(err.Error()) } @@ -158,10 +162,11 @@ func tunnels(m map[string]*Tunnel) map[string]*proto.Tunnel { for name, t := range m { p[name] = &proto.Tunnel{ - Protocol: t.Protocol, - Host: t.Host, - Auth: t.Auth, - Addr: t.RemoteAddr, + Protocol: t.Protocol, + Host: t.Host, + Auth: t.Auth, + LocalAddr: t.LocalAddr, + RemoteAddr: t.RemoteAddr, } } @@ -175,13 +180,13 @@ func proxy(m map[string]*Tunnel, logger log.Logger) tunnel.ProxyFunc { for _, t := range m { switch t.Protocol { case proto.HTTP: - u, err := url.Parse(t.Addr) + u, err := url.Parse(t.LocalAddr) if err != nil { fatal("invalid tunnel address: %s", err) } httpURL[t.Host] = u case proto.TCP, proto.TCP4, proto.TCP6: - tcpAddr[t.RemoteAddr] = t.Addr + tcpAddr[t.LocalAddr] = t.LocalAddr } } diff --git a/cmd/tunnel/version.go b/cli/tunnel/version.go similarity index 91% rename from cmd/tunnel/version.go rename to cli/tunnel/version.go index 189c430..8f12ee6 100644 --- a/cmd/tunnel/version.go +++ b/cli/tunnel/version.go @@ -2,6 +2,6 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunnel var version = "snapshot" diff --git a/cmd/tunneld/banner.go b/cli/tunneld/banner.go similarity index 97% rename from cmd/tunneld/banner.go rename to cli/tunneld/banner.go index c5cf848..d4f88c1 100644 --- a/cmd/tunneld/banner.go +++ b/cli/tunneld/banner.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld const banner = ` ______ __ __________________ __ __ diff --git a/cmd/tunneld/options.go b/cli/tunneld/options.go similarity index 90% rename from cmd/tunneld/options.go rename to cli/tunneld/options.go index 06994dc..331824a 100644 --- a/cmd/tunneld/options.go +++ b/cli/tunneld/options.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld import ( "flag" @@ -44,11 +44,12 @@ type options struct { tlsKey string rootCA string clients string + clientsDir string logLevel int version bool } -func parseArgs() *options { +func parseArgs(args ...string) *options { httpAddr := flag.String("httpAddr", ":80", "Public address for HTTP connections, empty string to disable") httpsAddr := flag.String("httpsAddr", ":443", "Public address listening for HTTPS connections, emptry string to disable") tunnelAddr := flag.String("tunnelAddr", ":5223", "Public address listening for tunnel client") @@ -56,9 +57,10 @@ func parseArgs() *options { tlsKey := flag.String("tlsKey", "server.key", "Path to a TLS key file") rootCA := flag.String("rootCA", "", "Path to the trusted certificate chian used for client certificate authentication, if empty any client certificate is accepted") clients := flag.String("clients", "", "Comma-separated list of tunnel client ids, if empty accept all clients") + clientsDir := flag.String("clientsDir", "", "Directory of registered clients") logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") version := flag.Bool("version", false, "Prints tunneld version") - flag.Parse() + flag.CommandLine.Parse(args) return &options{ httpAddr: *httpAddr, @@ -68,6 +70,7 @@ func parseArgs() *options { tlsKey: *tlsKey, rootCA: *rootCA, clients: *clients, + clientsDir: *clientsDir, logLevel: *logLevel, version: *version, } diff --git a/cmd/tunneld/tunneld.go b/cli/tunneld/tunneld.go similarity index 76% rename from cmd/tunneld/tunneld.go rename to cli/tunneld/tunneld.go index d7c1eba..c577bb6 100644 --- a/cmd/tunneld/tunneld.go +++ b/cli/tunneld/tunneld.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld import ( "crypto/tls" @@ -20,8 +20,12 @@ import ( "github.com/mmatczuk/go-http-tunnel/log" ) -func main() { - opts := parseArgs() +func Main() { + MainArgs(os.Args[1:]...) +} + +func MainArgs(args ...string) { + opts := parseArgs(args...) if opts.version { fmt.Println(version) @@ -37,7 +41,12 @@ func main() { fatal("failed to configure tls: %s", err) } - autoSubscribe := opts.clients == "" + autoSubscribe := opts.clients == "" && opts.clientsDir == "" + + var clientsProvider tunnel.RegisteredClientsProvider + if opts.clientsDir != "" { + clientsProvider = &tunnel.RegisteredClientsFileSystemProvider{opts.clientsDir} + } // setup server server, err := tunnel.NewServer(&tunnel.ServerConfig{ @@ -45,22 +54,29 @@ func main() { AutoSubscribe: autoSubscribe, TLSConfig: tlsconf, Logger: logger, + RegisteredClientsProvider: clientsProvider, }) if err != nil { fatal("failed to create server: %s", err) } if !autoSubscribe { - for _, c := range strings.Split(opts.clients, ",") { - if c == "" { - fatal("empty client id") + if opts.clientsDir != "" { + if _, err := os.Stat(opts.clientsDir); err != nil { + fatal(err.Error()) } - identifier := id.ID{} - err := identifier.UnmarshalText([]byte(c)) - if err != nil { - fatal("invalid identifier %q: %s", c, err) + } else { + for _, c := range strings.Split(opts.clients, ",") { + if c == "" { + fatal("empty client id") + } + identifier := id.ID{} + err := identifier.UnmarshalText([]byte(c)) + if err != nil { + fatal("invalid identifier %q: %s", c, err) + } + server.Subscribe(identifier) } - server.Subscribe(identifier) } } diff --git a/cmd/tunneld/version.go b/cli/tunneld/version.go similarity index 91% rename from cmd/tunneld/version.go rename to cli/tunneld/version.go index 189c430..696adfe 100644 --- a/cmd/tunneld/version.go +++ b/cli/tunneld/version.go @@ -2,6 +2,6 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -package main +package tunneld var version = "snapshot" diff --git a/client.go b/client.go index 3647ad7..4033c11 100644 --- a/client.go +++ b/client.go @@ -11,11 +11,14 @@ import ( "fmt" "net" "net/http" + "os" "sync" "time" "golang.org/x/net/http2" + "strconv" + "github.com/mmatczuk/go-http-tunnel/log" "github.com/mmatczuk/go-http-tunnel/proto" ) @@ -33,27 +36,41 @@ type ClientConfig struct { // Backoff specifies backoff policy on server connection retry. If nil // when dial fails it will not be retried. Backoff Backoff - // Tunnels specifies the tunnels client requests to be opened on server. + // Tunnels specifies the tunnels controller requests to be opened on server. Tunnels map[string]*proto.Tunnel // Proxy is ProxyFunc responsible for transferring data between server // and local services. Proxy ProxyFunc // Logger is optional logger. If nil logging is disabled. Logger log.Logger + // Name the controller name + Name string + // Description Client Description + Description string + // Registered if client is registered on server + Registered bool } -// Client is responsible for creating connection to the server, handling control -// messages. It uses ProxyFunc for transferring data between server and local -// services. -type Client struct { - config *ClientConfig - +type ClientToServerConn struct { + c *Client conn net.Conn connMu sync.Mutex httpServer *http2.Server serverErr error lastDisconnect time.Time logger log.Logger + handler http.HandlerFunc +} + +// Client is responsible for creating connection to the server, handling control +// messages. It uses ProxyFunc for transferring data between server and local +// services. +type Client struct { + *ClientToServerConn + config *ClientConfig + connectionCount int + connections []*ClientToServerConn + hostname string } // NewClient creates a new unconnected Client based on configuration. Caller @@ -78,19 +95,40 @@ func NewClient(config *ClientConfig) (*Client, error) { } c := &Client{ - config: config, + config: config, + } + + c.hostname, _ = os.Hostname() + + c.ClientToServerConn = &ClientToServerConn{ + c: c, httpServer: &http2.Server{}, logger: logger, - } + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodConnect { + if r.Header.Get(proto.HeaderError) != "" { + c.handleHandshakeError(w, r) + } else { + c.handleHandshake(w, r) + } + return + } + c.ClientToServerConn.serveHTTP(w, r) + }} return c, nil } -// Start connects client to the server, it returns error if there is a +// Hostname returns the client hostname +func (c *Client) Hostname() string { + return c.hostname +} + +// Start connects controller to the server, it returns error if there is a // connection error, or server cannot open requested tunnels. On connection // error a backoff policy is used to reestablish the connection. When connected // HTTP/2 server is started to handle ControlMessages. -func (c *Client) Start() error { +func (c *ClientToServerConn) Start() error { c.logger.Log( "level", 1, "action", "start", @@ -102,8 +140,14 @@ func (c *Client) Start() error { return err } + handler := c.handler + + if handler == nil { + handler = http.HandlerFunc(c.serveHTTP) + } + c.httpServer.ServeConn(conn, &http2.ServeConnOpts{ - Handler: http.HandlerFunc(c.serveHTTP), + Handler: handler, }) c.logger.Log( @@ -131,7 +175,7 @@ func (c *Client) Start() error { } } -func (c *Client) connect() (net.Conn, error) { +func (c *ClientToServerConn) connect() (net.Conn, error) { c.connMu.Lock() defer c.connMu.Unlock() @@ -148,11 +192,11 @@ func (c *Client) connect() (net.Conn, error) { return conn, nil } -func (c *Client) dial() (net.Conn, error) { +func (c *ClientToServerConn) dial() (net.Conn, error) { var ( network = "tcp" - addr = c.config.ServerAddr - tlsConfig = c.config.TLSClientConfig + addr = c.c.config.ServerAddr + tlsConfig = c.c.config.TLSClientConfig ) doDial := func() (conn net.Conn, err error) { @@ -163,8 +207,8 @@ func (c *Client) dial() (net.Conn, error) { "addr", addr, ) - if c.config.DialTLS != nil { - conn, err = c.config.DialTLS(network, addr, tlsConfig) + if c.c.config.DialTLS != nil { + conn, err = c.c.config.DialTLS(network, addr, tlsConfig) } else { d := &net.Dialer{ Timeout: DefaultTimeout, @@ -200,7 +244,7 @@ func (c *Client) dial() (net.Conn, error) { return } - b := c.config.Backoff + b := c.c.config.Backoff if b == nil { return doDial() } @@ -230,16 +274,7 @@ func (c *Client) dial() (net.Conn, error) { } } -func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodConnect { - if r.Header.Get(proto.HeaderError) != "" { - c.handleHandshakeError(w, r) - } else { - c.handleHandshake(w, r) - } - return - } - +func (c *ClientToServerConn) serveHTTP(w http.ResponseWriter, r *http.Request) { msg, err := proto.ReadControlMessage(r) if err != nil { c.logger.Log( @@ -257,7 +292,7 @@ func (c *Client) serveHTTP(w http.ResponseWriter, r *http.Request) { ) switch msg.Action { case proto.ActionProxy: - c.config.Proxy(w, r.Body, msg) + c.c.config.Proxy(w, r.Body, msg) default: c.logger.Log( "level", 0, @@ -297,7 +332,37 @@ func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - b, err := json.Marshal(c.config.Tunnels) + var info interface{} + + if c.config.Registered { + if cc := r.Header.Get(HeaderConnectionsCount); cc != "" { + c.connectionCount, _ = strconv.Atoi(cc) + } + + tunnels := map[string]*proto.RegisteredTunnel{} + for name, t := range c.config.Tunnels { + tunnels[name] = &proto.RegisteredTunnel{ + Disabled: t.Disabled, + LocalAddr: t.LocalAddr, + Auth: t.Auth, + } + } + + info = &RegisteredClientInfo{ + Hostname: c.hostname, + Tunnels: tunnels, + } + } else { + info = &RegisteredClientConfig{ + Name: c.config.Name, + Description: c.config.Description, + Tunnels: c.config.Tunnels, + Hostname: c.hostname, + } + } + + b, err := json.Marshal(info) + if err != nil { c.logger.Log( "level", 0, @@ -306,11 +371,26 @@ func (c *Client) handleHandshake(w http.ResponseWriter, r *http.Request) { ) return } - w.Write(b) + + if _, err := w.Write(b); err == nil && c.config.Registered && c.connectionCount > 0 { + if len(c.connections) == 0 { + for i := 0; i < c.connectionCount; i++ { + c.connections = append(c.connections, &ClientToServerConn{ + c: c, + httpServer: c.httpServer, + logger: log.NewContext(c.logger).With(fmt.Sprint("[", i, "]")), + }) + } + + for _, cc := range c.connections { + go cc.Start() + } + } + } } -// Stop disconnects client from server. -func (c *Client) Stop() { +// Stop disconnects controller from server. +func (c *ClientToServerConn) Stop() { c.connMu.Lock() defer c.connMu.Unlock() diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go new file mode 100644 index 0000000..21b02a2 --- /dev/null +++ b/cmd/tunnel/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunnel" + +func main() { + tunnel.Main() +} diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go new file mode 100644 index 0000000..119a380 --- /dev/null +++ b/cmd/tunneld/main.go @@ -0,0 +1,7 @@ +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunneld" + +func main() { + tunneld.Main() +} diff --git a/common.go b/common.go new file mode 100644 index 0000000..900263c --- /dev/null +++ b/common.go @@ -0,0 +1,17 @@ +package tunnel + +import ( + "net" + + "github.com/mmatczuk/go-http-tunnel/id" + "github.com/mmatczuk/go-http-tunnel/log" +) + +const HeaderConnectionsCount = "X-Connections-Count" + +type clientConnectionController struct { + cfg *RegisteredClientConfig + conn net.Conn + ID id.ID + logger log.Logger +} diff --git a/doc.go b/doc.go index 091aae9..279adcb 100644 --- a/doc.go +++ b/doc.go @@ -2,7 +2,7 @@ // Use of this source code is governed by an AGPL-style // license that can be found in the LICENSE file. -// Package tunnel is fast and secure client/server package that enables proxying +// Package tunnel is fast and secure controller/server package that enables proxying // public connections to your local machine over a tunnel connection from the // local machine to the public server. package tunnel diff --git a/errors.go b/errors.go index 1a6b53b..ea36386 100644 --- a/errors.go +++ b/errors.go @@ -7,9 +7,10 @@ package tunnel import "errors" var ( - errClientNotSubscribed = errors.New("client not subscribed") - errClientNotConnected = errors.New("client not connected") - errClientAlreadyConnected = errors.New("client already connected") + errClientManyConnections = errors.New("controller has many connections") + errClientNotSubscribed = errors.New("controller not subscribed") + errClientNotConnected = errors.New("controller not connected") + errClientAlreadyConnected = errors.New("controller already connected") errUnauthorised = errors.New("unauthorised") ) diff --git a/integration_test.go b/integration_test.go index 78fc039..3a9a3da 100644 --- a/integration_test.go +++ b/integration_test.go @@ -121,10 +121,11 @@ func makeTunnelClient(t testing.TB, serverAddr string, httpLocalAddr, httpAddr, Protocol: proto.HTTP, Host: "localhost", Auth: "user:password", + LocalAddr: httpLocalAddr.String(), }, proto.TCP: { Protocol: proto.TCP, - Addr: tcpLocalAddr.String(), + LocalAddr: tcpLocalAddr.String(), }, } @@ -165,12 +166,12 @@ func TestIntegration(t *testing.T) { httpLocalAddr := h.Listener.Addr() tcpLocalAddr := freeAddr() - // client + // controller c := makeTunnelClient(t, s.Addr(), httpLocalAddr, http.Addr(), tcpLocalAddr, tcp.Addr(), ) - // FIXME: replace sleep with client state change watch when ready + // FIXME: replace sleep with controller state change watch when ready time.Sleep(500 * time.Millisecond) defer c.Stop() diff --git a/pool.go b/pool.go index 2219b3f..219328b 100644 --- a/pool.go +++ b/pool.go @@ -17,14 +17,83 @@ import ( "github.com/mmatczuk/go-http-tunnel/id" ) +type clientConnections struct { + first *clientConnection + last *clientConnection + count int + mu sync.Mutex + lastid int +} + +func (cs *clientConnections) add(con *http2.ClientConn) *clientConnection { + cs.mu.Lock() + defer cs.mu.Unlock() + + c := &clientConnection{con, cs.last, nil, cs.lastid} + cs.lastid++ + if cs.last == nil { + cs.first = c + } else { + c.prev.next = c + } + cs.last = c + cs.count++ + return cs.last +} + +func (cs *clientConnections) remove(c *clientConnection) { + cs.mu.Lock() + defer cs.mu.Unlock() + + if cs.count == 1 { + cs.first = nil + cs.last = nil + } else if cs.last == c { + cs.last = c.prev + cs.last.next = nil + } else if cs.first == c { + cs.first = c.next + cs.first.prev = nil + } else { + p, n := c.prev, c.next + p.next = n + n.prev = p + } + + c.prev = nil + c.next = nil + cs.count-- +} + +type clientConnection struct { + *http2.ClientConn + prev *clientConnection + next *clientConnection + id int +} + type connPair struct { - conn net.Conn - clientConn *http2.ClientConn + controller *clientConnectionController + clientConn *clientConnection + conns *clientConnections + current *clientConnection + mu sync.Mutex +} + +func (cp *connPair) next() *clientConnection { + cp.mu.Lock() + defer cp.mu.Unlock() + if cp.current == nil || cp.current.next == nil { + cp.current = cp.conns.first + } else { + cp.current = cp.current.next + } + return cp.current } type connPool struct { t *http2.Transport - conns map[string]connPair // key is host:port + conns map[string]*connPair // key is host:port free func(identifier id.ID) mu sync.RWMutex } @@ -33,7 +102,7 @@ func newConnPool(t *http2.Transport, f func(identifier id.ID)) *connPool { return &connPool{ t: t, free: f, - conns: make(map[string]connPair), + conns: make(map[string]*connPair), } } @@ -45,8 +114,10 @@ func (p *connPool) GetClientConn(req *http.Request, addr string) (*http2.ClientC p.mu.RLock() defer p.mu.RUnlock() - if cp, ok := p.conns[addr]; ok && cp.clientConn.CanTakeNewRequest() { - return cp.clientConn, nil + if cp, ok := p.conns[addr]; ok { + conn := cp.next() + cp.controller.logger.Log("level", 3, "client_conn", fmt.Sprintf("#%d", conn.id), "addr", addr) + return conn.ClientConn, nil } return nil, errClientNotConnected @@ -57,18 +128,44 @@ func (p *connPool) MarkDead(c *http2.ClientConn) { defer p.mu.Unlock() for addr, cp := range p.conns { - if cp.clientConn == c { + if cp.clientConn.ClientConn == c { p.close(cp, addr) return } } } -func (p *connPool) AddConn(conn net.Conn, identifier id.ID) error { +func (p *connPool) Has(ID id.ID) (ok bool) { p.mu.Lock() defer p.mu.Unlock() - addr := p.addr(identifier) + addr := p.addr(ID) + _, ok = p.conns[addr] + return +} + +func (p *connPool) AddClientConnection(ID id.ID, conn net.Conn) (ccon *clientConnection, err error) { + c, err := p.t.NewClientConn(conn) + if err != nil { + return nil, err + } + + addr := p.addr(ID) + + cp := p.conns[addr] + if (cp.conns.count - 1) >= cp.controller.cfg.Connections { + return nil, errClientManyConnections + } + p.mu.Lock() + defer p.mu.Unlock() + return cp.conns.add(c), nil +} + +func (p *connPool) AddConn(controller *clientConnectionController) error { + p.mu.Lock() + defer p.mu.Unlock() + + addr := p.addr(controller.ID) if cp, ok := p.conns[addr]; ok { if err := p.ping(cp); err != nil { @@ -78,14 +175,13 @@ func (p *connPool) AddConn(conn net.Conn, identifier id.ID) error { } } - c, err := p.t.NewClientConn(conn) + c, err := p.t.NewClientConn(controller.conn) if err != nil { return err } - p.conns[addr] = connPair{ - conn: conn, - clientConn: c, - } + cp := &connPair{controller: controller, conns: &clientConnections{}} + cp.clientConn = cp.conns.add(c) + p.conns[addr] = cp return nil } @@ -116,15 +212,15 @@ func (p *connPool) Ping(identifier id.ID) (time.Duration, error) { return 0, errClientNotConnected } -func (p *connPool) ping(cp connPair) error { +func (p *connPool) ping(cp *connPair) error { ctx, cancel := context.WithTimeout(context.Background(), DefaultPingTimeout) defer cancel() return cp.clientConn.Ping(ctx) } -func (p *connPool) close(cp connPair, addr string) { - cp.conn.Close() +func (p *connPool) close(cp *connPair, addr string) { + cp.controller.conn.Close() delete(p.conns, addr) if p.free != nil { p.free(p.identifier(addr)) diff --git a/proto/controlmsg.go b/proto/controlmsg.go index 2dddc93..aaef5f1 100644 --- a/proto/controlmsg.go +++ b/proto/controlmsg.go @@ -13,6 +13,7 @@ import ( const ( HeaderError = "X-Error" + HeaderLocalAddr = "X-Local-RemoteAddr" HeaderAction = "X-Action" HeaderForwardedHost = "X-Forwarded-Host" HeaderForwardedProto = "X-Forwarded-Proto" @@ -38,6 +39,7 @@ const ( // used to inform client about the data and action to take. Based on that client // routes requests to backend services. type ControlMessage struct { + LocalAddr string Action string ForwardedHost string ForwardedProto string @@ -47,6 +49,7 @@ type ControlMessage struct { // ReadControlMessage reads ControlMessage from HTTP headers. func ReadControlMessage(r *http.Request) (*ControlMessage, error) { msg := ControlMessage{ + LocalAddr: r.Header.Get(HeaderLocalAddr), Action: r.Header.Get(HeaderAction), ForwardedHost: r.Header.Get(HeaderForwardedHost), ForwardedProto: r.Header.Get(HeaderForwardedProto), @@ -55,6 +58,9 @@ func ReadControlMessage(r *http.Request) (*ControlMessage, error) { var missing []string + if msg.LocalAddr == "" { + missing = append(missing, HeaderLocalAddr) + } if msg.Action == "" { missing = append(missing, HeaderAction) } @@ -74,6 +80,7 @@ func ReadControlMessage(r *http.Request) (*ControlMessage, error) { // WriteToHeader writes ControlMessage to HTTP header. func (c *ControlMessage) WriteToHeader(h http.Header) { + h.Set(HeaderLocalAddr, c.LocalAddr) h.Set(HeaderAction, string(c.Action)) h.Set(HeaderForwardedHost, c.ForwardedHost) h.Set(HeaderForwardedProto, c.ForwardedProto) diff --git a/proto/controlmsg_test.go b/proto/controlmsg_test.go index 51f39e0..564ca39 100644 --- a/proto/controlmsg_test.go +++ b/proto/controlmsg_test.go @@ -20,32 +20,41 @@ func TestControlMessageWriteRead(t *testing.T) { }{ { &ControlMessage{ + LocalAddr: "local_addr", Action: "action", ForwardedHost: "forwarded_host", ForwardedProto: "forwarded_proto", }, nil, }, + { + &ControlMessage{ + Action: "action", + ForwardedHost: "forwarded_host", + ForwardedProto: "forwarded_proto", + }, + errors.New("missing headers: [X-Local-RemoteAddr]"), + }, { &ControlMessage{ ForwardedHost: "forwarded_host", ForwardedProto: "forwarded_proto", }, - errors.New("missing headers: [X-Action]"), + errors.New("missing headers: [X-Local-RemoteAddr X-Action]"), }, { &ControlMessage{ Action: "action", ForwardedHost: "forwarded_host", }, - errors.New("missing headers: [X-Forwarded-Proto]"), + errors.New("missing headers: [X-Local-RemoteAddr X-Forwarded-Proto]"), }, { &ControlMessage{ Action: "action", ForwardedProto: "forwarded_proto", }, - errors.New("missing headers: [X-Forwarded-Host]"), + errors.New("missing headers: [X-Local-RemoteAddr X-Forwarded-Host]"), }, } diff --git a/proto/tunnel.go b/proto/tunnel.go index 2b5e71e..b2e38dc 100644 --- a/proto/tunnel.go +++ b/proto/tunnel.go @@ -6,8 +6,10 @@ package proto // Tunnel describes a single tunnel between client and server. When connecting // client sends tunnels to server. If client gets connected server proxies -// connections to given Host and Addr to the client. +// connections to given Host and RemoteAddr to the client. type Tunnel struct { + // Disabled tunnel disabled + Disabled bool // Protocol specifies tunnel protocol, must be one of protocols known // by the server. Protocol string @@ -17,7 +19,22 @@ type Tunnel struct { // Auth specifies HTTP basic auth credentials in form "user:password", // if set server would protect HTTP and WS tunnels with basic auth. Auth string - // Addr specifies TCP address server would listen on, it's required + // RemoteAddr specifies TCP address server would listen on, it's required // for TCP tunnels. - Addr string + RemoteAddr string `yaml:"remote_addr"` + // LocalAddr client local addr + LocalAddr string `yaml:"local_addr"` +} + +// RegisteredTunnel describes a single registered tunnel between client and server. When connecting +// registered client sends tunnels to server. If client gets connected server proxies +// connections to given Host and LocalAddr to the client. +type RegisteredTunnel struct { + // Disabled tunnel disabled + Disabled bool + // Auth specifies HTTP basic auth credentials in form "user:password", + // if set server would protect HTTP and WS tunnels with basic auth. + Auth string + // LocalAddr client local addr + LocalAddr string `yaml:"local_addr"` } diff --git a/registered_client.go b/registered_client.go new file mode 100644 index 0000000..bd36e17 --- /dev/null +++ b/registered_client.go @@ -0,0 +1,76 @@ +package tunnel + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + + "github.com/mmatczuk/go-http-tunnel/id" + "github.com/mmatczuk/go-http-tunnel/proto" + "gopkg.in/yaml.v2" +) + +type RegisteredClientInfo struct { + Hostname string + Tunnels map[string]*proto.RegisteredTunnel +} + +type ErrClientNotRegistered struct { + ClientID id.ID +} + +func (e ErrClientNotRegistered) Error() string { + return fmt.Sprintf("Client %q not registered", e.ClientID.String()) +} + +func IsNotRegistered(err error) (ok bool) { + if err == nil { + return + } + _, ok = err.(ErrClientNotRegistered) + return +} + +type RegisteredClientConfig struct { + ID id.ID + Name string + Description string + Hostname string + Tunnels map[string]*proto.Tunnel + Connections int +} + +type RegisteredClientsProvider interface { + Get(clientID id.ID) (client *RegisteredClientConfig, err error) +} + +type RegisteredClientsFileSystemProvider struct { + StorageDir string +} + +func (p *RegisteredClientsFileSystemProvider) Get(clientID id.ID) (client *RegisteredClientConfig, err error) { + pth := filepath.Join(p.StorageDir, clientID.String()+".yaml") + var f io.ReadCloser + if f, err = os.Open(pth); err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + defer f.Close() + var data []byte + if data, err = ioutil.ReadAll(f); err != nil { + return nil, err + } + client = &RegisteredClientConfig{} + if err = yaml.Unmarshal(data, client); err != nil { + return nil, err + } + client.ID = clientID + if client.Tunnels == nil { + client.Tunnels = map[string]*proto.Tunnel{} + } + return +} diff --git a/registry.go b/registry.go index 98fead8..ef03246 100644 --- a/registry.go +++ b/registry.go @@ -9,24 +9,32 @@ import ( "net" "sync" + "github.com/mmatczuk/go-http-tunnel/proto" + "github.com/mmatczuk/go-http-tunnel/id" "github.com/mmatczuk/go-http-tunnel/log" ) +type Listener struct { + net.Listener + Tunnel *proto.Tunnel +} + // RegistryItem holds information about hosts and listeners associated with a -// client. +// controller. type RegistryItem struct { Hosts []*HostAuth - Listeners []net.Listener + Listeners []*Listener } // HostAuth holds host and authentication info. type HostAuth struct { - Host string - Auth *Auth + Tunnel *proto.Tunnel + Auth *Auth } type hostInfo struct { + tunnel *proto.Tunnel identifier id.ID auth *Auth } @@ -52,7 +60,7 @@ func newRegistry(logger log.Logger) *registry { var voidRegistryItem = &RegistryItem{} -// Subscribe allows to connect client with a given identifier. +// Subscribe allows to connect controller with a given identifier. func (r *registry) Subscribe(identifier id.ID) { r.mu.Lock() defer r.mu.Unlock() @@ -70,7 +78,7 @@ func (r *registry) Subscribe(identifier id.ID) { r.items[identifier] = voidRegistryItem } -// IsSubscribed returns true if client is subscribed. +// IsSubscribed returns true if controller is subscribed. func (r *registry) IsSubscribed(identifier id.ID) bool { r.mu.RLock() defer r.mu.RUnlock() @@ -78,8 +86,8 @@ func (r *registry) IsSubscribed(identifier id.ID) bool { return ok } -// Subscriber returns client identifier assigned to given host. -func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) { +// Subscriber returns controller identifier assigned to given host. +func (r *registry) Subscriber(hostPort string) (id.ID, *hostInfo, bool) { r.mu.RLock() defer r.mu.RUnlock() @@ -88,10 +96,10 @@ func (r *registry) Subscriber(hostPort string) (id.ID, *Auth, bool) { return id.ID{}, nil, false } - return h.identifier, h.auth, ok + return h.identifier, h, ok } -// Unsubscribe removes client from registry and returns it's RegistryItem. +// Unsubscribe removes controller from registry and returns it's RegistryItem. func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem { r.mu.Lock() defer r.mu.Unlock() @@ -109,7 +117,7 @@ func (r *registry) Unsubscribe(identifier id.ID) *RegistryItem { if i.Hosts != nil { for _, h := range i.Hosts { - delete(r.hosts, h.Host) + delete(r.hosts, h.Tunnel.Host) } } @@ -141,13 +149,14 @@ func (r *registry) set(i *RegistryItem, identifier id.ID) error { if h.Auth != nil && h.Auth.User == "" { return fmt.Errorf("missing auth user") } - if _, ok := r.hosts[trimPort(h.Host)]; ok { - return fmt.Errorf("host %q is occupied", h.Host) + if _, ok := r.hosts[trimPort(h.Tunnel.Host)]; ok { + return fmt.Errorf("host %q is occupied", h.Tunnel.Host) } } for _, h := range i.Hosts { - r.hosts[trimPort(h.Host)] = &hostInfo{ + r.hosts[trimPort(h.Tunnel.Host)] = &hostInfo{ + tunnel: h.Tunnel, identifier: identifier, auth: h.Auth, } @@ -176,7 +185,7 @@ func (r *registry) clear(identifier id.ID) *RegistryItem { if i.Hosts != nil { for _, h := range i.Hosts { - delete(r.hosts, trimPort(h.Host)) + delete(r.hosts, trimPort(h.Tunnel.Host)) } } diff --git a/server.go b/server.go index b3f7bd6..a562871 100644 --- a/server.go +++ b/server.go @@ -13,6 +13,7 @@ import ( "io" "net" "net/http" + "strconv" "strings" "time" @@ -25,7 +26,7 @@ import ( // ServerConfig defines configuration for the Server. type ServerConfig struct { - // Addr is TCP address to listen for client connections. If empty ":0" + // RemoteAddr is TCP address to listen for controller connections. If empty ":0" // is used. Addr string // AutoSubscribe if enabled will automatically subscribe new clients on @@ -33,14 +34,16 @@ type ServerConfig struct { AutoSubscribe bool // TLSConfig specifies the tls configuration to use with tls.Listener. TLSConfig *tls.Config - // Listener specifies optional listener for client connections. If nil - // tls.Listen("tcp", Addr, TLSConfig) is used. + // Listener specifies optional listener for controller connections. If nil + // tls.Listen("tcp", RemoteAddr, TLSConfig) is used. Listener net.Listener // Logger is optional logger. If nil logging is disabled. Logger log.Logger + // RegisteredClient get + RegisteredClientsProvider RegisteredClientsProvider } -// Server is responsible for proxying public connections to the client over a +// Server is responsible for proxying public connections to the controller over a // tunnel connection. type Server struct { *registry @@ -91,7 +94,7 @@ func listener(config *ServerConfig) (net.Listener, error) { } if config.Addr == "" { - return nil, errors.New("missing Addr") + return nil, errors.New("missing RemoteAddr") } if config.TLSConfig == nil { return nil, errors.New("missing TLSConfig") @@ -100,8 +103,8 @@ func listener(config *ServerConfig) (net.Listener, error) { return net.Listen("tcp", config.Addr) } -// disconnected clears resources used by client, it's invoked by connection pool -// when client goes away. +// disconnected clears resources used by controller, it's invoked by connection pool +// when controller goes away. func (s *Server) disconnected(identifier id.ID) { s.logger.Log( "level", 1, @@ -165,11 +168,11 @@ func (s *Server) Start() { ) } - go s.handleClient(tls.Server(conn, s.config.TLSConfig)) + go s.handleClient(tls.Server(conn, s.config.TLSConfig), true) } } -func (s *Server) handleClient(conn net.Conn) { +func (s *Server) handleClient(conn net.Conn, main bool) { logger := log.NewContext(s.logger).With("addr", conn.RemoteAddr()) logger.Log( @@ -178,46 +181,77 @@ func (s *Server) handleClient(conn net.Conn) { ) var ( - identifier id.ID - req *http.Request - resp *http.Response - tunnels map[string]*proto.Tunnel err error - ok bool - inConnPool bool + ok bool + ID id.ID + cfg *RegisteredClientConfig + clientInfo *RegisteredClientInfo + controller *clientConnectionController + body io.Reader + req *http.Request + resp *http.Response + tlsConn *tls.Conn + ccon *clientConnection ) - tlsConn, ok := conn.(*tls.Conn) - if !ok { + if tlsConn, ok = conn.(*tls.Conn); !ok { logger.Log( "level", 0, "msg", "invalid connection type", - "err", fmt.Errorf("expected TLS conn, got %T", conn), + "err", fmt.Errorf("expected TLS controller, got %T", conn), ) - goto reject + goto Reject } - identifier, err = id.PeerID(tlsConn) + ID, err = id.PeerID(tlsConn) if err != nil { logger.Log( "level", 2, "msg", "certificate error", "err", err, ) - goto reject + goto Reject } - logger = logger.With("identifier", identifier) + logger = logger.With("identifier", ID) if s.config.AutoSubscribe { - s.Subscribe(identifier) - } else if !s.IsSubscribed(identifier) { + s.Subscribe(ID) + } else if s.config.RegisteredClientsProvider != nil { + if cfg, err = s.config.RegisteredClientsProvider.Get(ID); err == nil { + if s.connPool.Has(ID) { + if ccon, err = s.connPool.AddClientConnection(ID, conn); err != nil { + goto Reject + } + logger.Log( + "level", 1, + "msg", fmt.Sprintf("new connection #%d for client added", ccon.id), + ) + return + } else { + s.Subscribe(ID) + } + } else if IsNotRegistered(err) { + logger.Log( + "level", 2, + "msg", err.Error(), + ) + goto Reject + } else { + logger.Log( + "level", 2, + "msg", "Get registered controller failed", + "err", err, + ) + goto Reject + } + } else if !s.IsSubscribed(ID) { logger.Log( "level", 2, - "msg", "unknown client", + "msg", "unknown controller", ) - goto reject + goto Reject } if err = conn.SetDeadline(time.Time{}); err != nil { @@ -226,27 +260,34 @@ func (s *Server) handleClient(conn net.Conn) { "msg", "setting infinite deadline failed", "err", err, ) - goto reject + goto Reject } - if err := s.connPool.AddConn(conn, identifier); err != nil { + controller = &clientConnectionController{cfg, conn, ID, logger} + + if err := s.connPool.AddConn(controller); err != nil { logger.Log( "level", 2, "msg", "adding connection failed", "err", err, ) - goto reject + goto Reject } + inConnPool = true - req, err = http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil) + req, err = http.NewRequest(http.MethodConnect, s.connPool.URL(controller.ID), body) if err != nil { logger.Log( "level", 2, "msg", "handshake request creation failed", "err", err, ) - goto reject + goto Reject + } + + if cfg != nil && cfg.Connections > 0 { + req.Header.Set(HeaderConnectionsCount, strconv.Itoa(int(controller.cfg.Connections))) } { @@ -262,7 +303,7 @@ func (s *Server) handleClient(conn net.Conn) { "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } if resp.StatusCode != http.StatusOK { @@ -272,7 +313,7 @@ func (s *Server) handleClient(conn net.Conn) { "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } if resp.ContentLength == 0 { @@ -282,35 +323,67 @@ func (s *Server) handleClient(conn net.Conn) { "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } - if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(&tunnels); err != nil { - logger.Log( - "level", 2, - "msg", "handshake failed", - "err", err, - ) - goto reject + if cfg == nil { + if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(cfg); err != nil { + logger.Log( + "level", 2, + "msg", "handshake failed", + "err", err, + ) + goto Reject + } + // Ony main connection + cfg.Connections = 0 + } else { + clientInfo = &RegisteredClientInfo{} + if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(clientInfo); err != nil { + logger.Log( + "level", 2, + "msg", "handshake failed", + "err", err, + ) + goto Reject + } + + cfg.Hostname = clientInfo.Hostname + for name, ct := range clientInfo.Tunnels { + if t, ok := cfg.Tunnels[name]; ok { + if ct.Disabled { + t.Disabled = true + } else { + t.LocalAddr = ct.LocalAddr + t.Auth = ct.Auth + } + } else { + logger.Log( + "level", 1, + "msg", fmt.Sprintf("tunnel %q not registered", name), + ) + delete(clientInfo.Tunnels, name) + } + } } - if len(tunnels) == 0 { + if len(cfg.Tunnels) == 0 { err = fmt.Errorf("No tunnels") logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } - if err = s.addTunnels(tunnels, identifier); err != nil { + if err = s.addTunnels(cfg.Tunnels, ID); err != nil { logger.Log( "level", 2, "msg", "handshake failed", "err", err, ) - goto reject + goto Reject } logger.Log( @@ -319,22 +392,21 @@ func (s *Server) handleClient(conn net.Conn) { ) return - -reject: +Reject: logger.Log( "level", 1, "action", "rejected", ) if inConnPool { - s.notifyError(err, identifier) - s.connPool.DeleteConn(identifier) + s.notifyError(err, ID) + s.connPool.DeleteConn(ID) } conn.Close() } -// notifyError tries to send error to client. +// notifyError tries to send error to controller. func (s *Server) notifyError(serverError error, identifier id.ID) { if serverError == nil { return @@ -344,7 +416,7 @@ func (s *Server) notifyError(serverError error, identifier id.ID) { if err != nil { s.logger.Log( "level", 2, - "action", "client error notification failed", + "action", "controller error notification failed", "identifier", identifier, "err", err, ) @@ -364,17 +436,21 @@ func (s *Server) notifyError(serverError error, identifier id.ID) { func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) error { i := &RegistryItem{ Hosts: []*HostAuth{}, - Listeners: []net.Listener{}, + Listeners: []*Listener{}, } var err error for name, t := range tunnels { + if t.Disabled { + continue + } + switch t.Protocol { case proto.HTTP: - i.Hosts = append(i.Hosts, &HostAuth{t.Host, NewAuth(t.Auth)}) + i.Hosts = append(i.Hosts, &HostAuth{t, NewAuth(t.Auth)}) case proto.TCP, proto.TCP4, proto.TCP6, proto.UNIX: var l net.Listener - l, err = net.Listen(t.Protocol, t.Addr) + l, err = net.Listen(t.Protocol, t.RemoteAddr) if err != nil { goto rollback } @@ -386,7 +462,7 @@ func (s *Server) addTunnels(tunnels map[string]*proto.Tunnel, identifier id.ID) "addr", l.Addr(), ) - i.Listeners = append(i.Listeners, l) + i.Listeners = append(i.Listeners, &Listener{l, t}) default: err = fmt.Errorf("unsupported protocol for tunnel %s: %s", name, t.Protocol) goto rollback @@ -412,7 +488,7 @@ rollback: return err } -// Unsubscribe removes client from registry, disconnects client if already +// Unsubscribe removes controller from registry, disconnects controller if already // connected and returns it's RegistryItem. func (s *Server) Unsubscribe(identifier id.ID) *RegistryItem { s.connPool.DeleteConn(identifier) @@ -424,7 +500,7 @@ func (s *Server) Ping(identifier id.ID) (time.Duration, error) { return s.connPool.Ping(identifier) } -func (s *Server) listen(l net.Listener, identifier id.ID) { +func (s *Server) listen(l *Listener, identifier id.ID) { addr := l.Addr().String() for { @@ -451,6 +527,7 @@ func (s *Server) listen(l net.Listener, identifier id.ID) { } msg := &proto.ControlMessage{ + LocalAddr: l.Tunnel.LocalAddr, Action: proto.ActionProxy, ForwardedHost: l.Addr().String(), ForwardedProto: l.Addr().Network(), @@ -480,7 +557,7 @@ func (s *Server) listen(l net.Listener, identifier id.ID) { } } -// ServeHTTP proxies http connection to the client. +// ServeHTTP proxies http connection to the controller. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { resp, err := s.RoundTrip(r) if err == errUnauthorised { @@ -507,7 +584,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.WriteHeader(resp.StatusCode) transfer(w, resp.Body, log.NewContext(s.logger).With( - "dir", "client to user", + "dir", "controller to user", "dst", r.RemoteAddr, "src", r.Host, )) @@ -515,7 +592,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // RoundTrip is http.RoundTriper implementation. func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { - identifier, auth, ok := s.Subscriber(r.Host) + identifier, hostInfo, ok := s.Subscriber(r.Host) if !ok { return nil, errClientNotSubscribed } @@ -526,9 +603,9 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { } outr.Header = cloneHeader(r.Header) - if auth != nil { + if hostInfo.auth != nil { user, password, _ := r.BasicAuth() - if auth.User != user || auth.Password != password { + if hostInfo.auth.User != user || hostInfo.auth.Password != password { return nil, errUnauthorised } outr.Header.Del("Authorization") @@ -550,6 +627,7 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { } msg := &proto.ControlMessage{ + LocalAddr: hostInfo.tunnel.LocalAddr, Action: proto.ActionProxy, ForwardedHost: r.Host, ForwardedProto: scheme, @@ -561,7 +639,7 @@ func (s *Server) RoundTrip(r *http.Request) (*http.Response, error) { func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMessage) error { s.logger.Log( "level", 2, - "action", "proxy conn", + "action", "proxy controller", "identifier", identifier, "ctrlMsg", msg, ) @@ -583,7 +661,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe done := make(chan struct{}) go func() { transfer(pw, conn, log.NewContext(s.logger).With( - "dir", "user to client", + "dir", "user to controller", "dst", identifier, "src", conn.RemoteAddr(), )) @@ -598,7 +676,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe defer resp.Body.Close() transfer(conn, resp.Body, log.NewContext(s.logger).With( - "dir", "client to user", + "dir", "controller to user", "dst", conn.RemoteAddr(), "src", identifier, )) @@ -607,7 +685,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe s.logger.Log( "level", 2, - "action", "proxy conn done", + "action", "proxy controller done", "identifier", identifier, "ctrlMsg", msg, ) @@ -650,7 +728,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control "action", "transferred", "identifier", identifier, "bytes", cw.count, - "dir", "user to client", + "dir", "user to controller", "dst", r.Host, "src", r.RemoteAddr, ) @@ -676,7 +754,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control return resp, nil } -// connectRequest creates HTTP request to client with a given identifier having +// connectRequest creates HTTP request to controller with a given identifier having // control message and data input stream, output data stream results from // response the created request. func (s *Server) connectRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) { diff --git a/tcpproxy.go b/tcpproxy.go index 613de9b..1820aee 100644 --- a/tcpproxy.go +++ b/tcpproxy.go @@ -68,7 +68,7 @@ func (p *TCPProxy) Proxy(w io.Writer, r io.ReadCloser, msg *proto.ControlMessage return } - target := p.localAddrFor(msg.ForwardedHost) + target := p.localAddrFor(msg.LocalAddr) if target == "" { p.logger.Log( "level", 1, From 0cccde7a29fbeb47aeaab61700f28c1184527e0d Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Thu, 11 Oct 2018 17:11:12 -0300 Subject: [PATCH 02/12] Update README and config --- README.md | 109 ++++++++++++++++++++++++++++++++++++----- cli/tunnel/config.go | 2 +- cli/tunneld/options.go | 2 +- client.go | 7 +-- proto/tunnel.go | 2 +- 5 files changed, 104 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index f11dcfc..99f8fad 100644 --- a/README.md +++ b/README.md @@ -79,37 +79,122 @@ Sample configuration that exposes: looks like this ```yaml - server_addr: SERVER_IP:5223 - tunnels: - webui: - proto: http - addr: localhost:8080 - auth: user:password - host: webui.my-tunnel-host.com - ssh: - proto: tcp - addr: 192.168.0.5:22 - remote_addr: 0.0.0.0:22 +server_addr: SERVER_IP:5223 +tunnels: + webui: + proto: http + local_addr: localhost:8080 + auth: user:password + host: webui.my-tunnel-host.com + ssh: + proto: tcp + local_addr: 192.168.0.5:22 + remote_addr: 0.0.0.0:22 ``` Configuration options: +* `name`: set the client name (not required). +* `description`: set the client name (not required). * `server_addr`: server TCP address, i.e. `54.12.12.45:5223` * `tls_crt`: path to client TLS certificate, *default:* `client.crt` *in the config file directory* * `tls_key`: path to client TLS certificate key, *default:* `client.key` *in the config file directory* * `root_ca`: path to trusted root certificate authority pool file, if empty any server certificate is accepted * `tunnels / [name]` * `proto`: tunnel protocol, `http` or `tcp` - * `addr`: forward traffic to this local port number or network address, for `proto=http` this can be full URL i.e. `https://machine/sub/path/?plus=params`, supports URL schemes `http` and `https` + * `local_addr`: forward traffic to this local port number or network address, for `proto=http` this can be full URL i.e. `https://machine/sub/path/?plus=params`, supports URL schemes `http` and `https` * `auth`: (`proto=http`) (optional) basic authentication credentials to enforce on tunneled requests, format `user:password` * `host`: (`proto=http`) hostname to request (requires reserved name and DNS CNAME) * `remote_addr`: (`proto=tcp`) bind the remote TCP address + * `disabled`: if set to `true` this tunnel has be ignored. * `backoff` * `interval`: how long client would wait before redialing the server if connection was lost, exponential backoff initial interval, *default:* `500ms` * `multiplier`: interval multiplier if reconnect failed, *default:* `1.5` * `max_interval`: maximal time client would wait before redialing the server, *default:* `1m` * `max_time`: maximal time client would try to reconnect to the server if connection was lost, set `0` to never stop trying, *default:* `15m` +## Client Registration + +### Client + +Configuration file `tunnel.yaml`: + +```yaml +registered: true +server_addr: SERVER_IP:5223 +tunnels: + webui: + proto: http + local_addr: localhost:8080 + auth: user:password + ssh: + proto: tcp + local_addr: 192.168.0.5:22 +``` + +Configuration options: + +* `registered`: registered client. Set as `true`. +* `server_addr`: server TCP address, i.e. `54.12.12.45:5223` +* `tls_crt`: path to client TLS certificate, *default:* `client.crt` *in the config file directory* +* `tls_key`: path to client TLS certificate key, *default:* `client.key` *in the config file directory* +* `root_ca`: path to trusted root certificate authority pool file, if empty any server certificate is accepted +* `tunnels / [name]` + * `proto`: tunnel protocol, `http` or `tcp` + * `local_addr`: forward traffic to this local port number or network address, for `proto=http` this can be full URL i.e. `https://machine/sub/path/?plus=params`, supports URL schemes `http` and `https` + * `auth`: (`proto=http`) (optional) basic authentication credentials to enforce on tunneled requests, format `user:password` + * `disabled`: if set to `true` this tunnel has be ignored. +* `backoff` + * `interval`: how long client would wait before redialing the server if connection was lost, exponential backoff initial interval, *default:* `500ms` + * `multiplier`: interval multiplier if reconnect failed, *default:* `1.5` + * `max_interval`: maximal time client would wait before redialing the server, *default:* `1m` + * `max_time`: maximal time client would try to reconnect to the server if connection was lost, set `0` to never stop trying, *default:* `15m` + +### Server + +Create registered clients database structure if not exists: + +```bash +$ mkdir registered_clients +``` + +On client, get ID: `$ tunnel id` + +Configure client (file `registered_clients/CLIENT_ID.yaml`). +Example `registered_clients/SS2KPSV-5KG2URM-WYUEDLY-FBDAD7A-MUUVTDX-TL7KL45-2PQEQAD-IN4LVAH.yaml`: + +```yaml +connections: 4 +tunnels: + webui: + proto: http + host: mps.dea.ufv.br + ssh: + proto: tcp + remote_addr: 127.0.0.1:2222 +``` + +Configuration options: +* `disabled`: if set to `true` block this client. +* `connections`: number of connections in addition to the main connection. Default is `0`. +* `tunnels / [name]` + * `proto`: tunnel protocol, `http` or `tcp` + * `host`: (`proto=http`) hostname to request (requires reserved name and DNS CNAME) + * `remote_addr`: (`proto=tcp`) bind the remote TCP address + * `disabled`: if set to `true` this tunnel has be ignored. + +### Run + +Server: +```bash +$ tunneld -registeredClientsDB registered_clients +``` + +Client: +```bash +$ tunnel start-all +``` + ## How it works A client opens TLS connection to a server. The server accepts connections from known clients only. The client is recognized by its TLS certificate ID. The server is publicly available and proxies incoming connections to the client. Then the connection is further proxied in the client's network. diff --git a/cli/tunnel/config.go b/cli/tunnel/config.go index 67f3fb0..da13885 100644 --- a/cli/tunnel/config.go +++ b/cli/tunnel/config.go @@ -33,7 +33,7 @@ type BackoffConfig struct { // Tunnel defines a tunnel. type Tunnel struct { - Protocol string `yaml:"protocol,omitempty"` + Protocol string `yaml:"proto,omitempty"` LocalAddr string `yaml:"local_addr,omitempty"` Auth string `yaml:"auth,omitempty"` Host string `yaml:"host,omitempty"` diff --git a/cli/tunneld/options.go b/cli/tunneld/options.go index 331824a..39daf64 100644 --- a/cli/tunneld/options.go +++ b/cli/tunneld/options.go @@ -57,7 +57,7 @@ func parseArgs(args ...string) *options { tlsKey := flag.String("tlsKey", "server.key", "Path to a TLS key file") rootCA := flag.String("rootCA", "", "Path to the trusted certificate chian used for client certificate authentication, if empty any client certificate is accepted") clients := flag.String("clients", "", "Comma-separated list of tunnel client ids, if empty accept all clients") - clientsDir := flag.String("clientsDir", "", "Directory of registered clients") + clientsDir := flag.String("registeredClientsDB", "", "Database directory of registered clients") logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") version := flag.Bool("version", false, "Prints tunneld version") flag.CommandLine.Parse(args) diff --git a/client.go b/client.go index 4033c11..24643e3 100644 --- a/client.go +++ b/client.go @@ -43,12 +43,13 @@ type ClientConfig struct { Proxy ProxyFunc // Logger is optional logger. If nil logging is disabled. Logger log.Logger - // Name the controller name + // Name the controller name. Name string - // Description Client Description + // Description Client Description. Description string - // Registered if client is registered on server + // Registered if client is registered on server. Registered bool + // Disabled if set to true block this client. } type ClientToServerConn struct { diff --git a/proto/tunnel.go b/proto/tunnel.go index b2e38dc..2d4c96a 100644 --- a/proto/tunnel.go +++ b/proto/tunnel.go @@ -12,7 +12,7 @@ type Tunnel struct { Disabled bool // Protocol specifies tunnel protocol, must be one of protocols known // by the server. - Protocol string + Protocol string `yaml:"proto"` // Host specified HTTP request host, it's required for HTTP and WS // tunnels. Host string From 57ab154a65d027b42f6cf97134f44d7830264bba Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Mon, 15 Oct 2018 10:28:44 -0300 Subject: [PATCH 03/12] Add registered client enabled/disabled check --- registered_client.go | 8 +++++++- server.go | 7 +++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/registered_client.go b/registered_client.go index bd36e17..66fef99 100644 --- a/registered_client.go +++ b/registered_client.go @@ -34,6 +34,7 @@ func IsNotRegistered(err error) (ok bool) { } type RegisteredClientConfig struct { + Disabled bool ID id.ID Name string Description string @@ -51,7 +52,8 @@ type RegisteredClientsFileSystemProvider struct { } func (p *RegisteredClientsFileSystemProvider) Get(clientID id.ID) (client *RegisteredClientConfig, err error) { - pth := filepath.Join(p.StorageDir, clientID.String()+".yaml") + base := filepath.Join(p.StorageDir, clientID.String()) + pth := filepath.Join(base, "config.yaml") var f io.ReadCloser if f, err = os.Open(pth); err != nil { if os.IsNotExist(err) { @@ -72,5 +74,9 @@ func (p *RegisteredClientsFileSystemProvider) Get(clientID id.ID) (client *Regis if client.Tunnels == nil { client.Tunnels = map[string]*proto.Tunnel{} } + + if _, err := os.Stat(filepath.Join(base, "disabled")); err == nil { + client.Disabled = true + } return } diff --git a/server.go b/server.go index a562871..fcf27a0 100644 --- a/server.go +++ b/server.go @@ -220,6 +220,13 @@ func (s *Server) handleClient(conn net.Conn, main bool) { s.Subscribe(ID) } else if s.config.RegisteredClientsProvider != nil { if cfg, err = s.config.RegisteredClientsProvider.Get(ID); err == nil { + if cfg.Disabled { + logger.Log( + "level", 2, + "msg", "Client has be disabled", + ) + goto Reject + } if s.connPool.Has(ID) { if ccon, err = s.connPool.AddClientConnection(ID, conn); err != nil { goto Reject From 053e2cb8021998c1ca49a86379ad83f4f36b25c7 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Mon, 15 Oct 2018 10:33:54 -0300 Subject: [PATCH 04/12] Update README --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 99f8fad..481ddbe 100644 --- a/README.md +++ b/README.md @@ -160,8 +160,10 @@ $ mkdir registered_clients On client, get ID: `$ tunnel id` -Configure client (file `registered_clients/CLIENT_ID.yaml`). -Example `registered_clients/SS2KPSV-5KG2URM-WYUEDLY-FBDAD7A-MUUVTDX-TL7KL45-2PQEQAD-IN4LVAH.yaml`: +Registered **Client DIR** `registered_clients/CLIENT_ID`. + +Configure client (file `config.yaml`). +Example `registered_clients/SS2KPSV-5KG2URM-WYUEDLY-FBDAD7A-MUUVTDX-TL7KL45-2PQEQAD-IN4LVAH/config.yaml`: ```yaml connections: 4 @@ -175,7 +177,8 @@ tunnels: ``` Configuration options: -* `disabled`: if set to `true` block this client. +* `disabled`: if set to `true` block this client. Or create empty file `disabled` on client dir. +Example: `$ touch registered_clients/CLIENT_ID/disabled`. * `connections`: number of connections in addition to the main connection. Default is `0`. * `tunnels / [name]` * `proto`: tunnel protocol, `http` or `tcp` From 9b0eb8c40b6226e8d2fa37e52f3583cac54a2221 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Mon, 15 Oct 2018 15:54:06 -0300 Subject: [PATCH 05/12] Fix registered clients multiple connections priority order. --- cli/tunnel/tunnel.go | 1 + client.go | 1 - pool.go | 64 ++++++++++++++++++++++++++++++++++++-------- registered_client.go | 2 +- server.go | 15 +++++++---- 5 files changed, 65 insertions(+), 18 deletions(-) diff --git a/cli/tunnel/tunnel.go b/cli/tunnel/tunnel.go index 9ebbbc3..1d16821 100644 --- a/cli/tunnel/tunnel.go +++ b/cli/tunnel/tunnel.go @@ -106,6 +106,7 @@ func MainArgs(args ...string) { Tunnels: tunnels(config.Tunnels), Proxy: proxy(config.Tunnels, logger), Logger: logger, + Registered: config.Registered, }) if err != nil { fatal("failed to create client: %s", err) diff --git a/client.go b/client.go index 24643e3..d2e206f 100644 --- a/client.go +++ b/client.go @@ -49,7 +49,6 @@ type ClientConfig struct { Description string // Registered if client is registered on server. Registered bool - // Disabled if set to true block this client. } type ClientToServerConn struct { diff --git a/pool.go b/pool.go index 219328b..0009569 100644 --- a/pool.go +++ b/pool.go @@ -7,8 +7,10 @@ package tunnel import ( "context" "fmt" + "io" "net" "net/http" + "sort" "sync" "time" @@ -17,19 +19,31 @@ import ( "github.com/mmatczuk/go-http-tunnel/id" ) +const reqPool = "req.pool" + type clientConnections struct { first *clientConnection last *clientConnection count int mu sync.Mutex lastid int + conns []*clientConnection +} + +func (cs *clientConnections) low() *clientConnection { + cs.mu.Lock() + defer cs.mu.Unlock() + sort.Slice(cs.conns, func(i, j int) bool { + return cs.conns[i].count < cs.conns[j].count + }) + return cs.conns[0].increase() } func (cs *clientConnections) add(con *http2.ClientConn) *clientConnection { cs.mu.Lock() defer cs.mu.Unlock() - c := &clientConnection{con, cs.last, nil, cs.lastid} + c := &clientConnection{con, cs.last, nil, cs.lastid, 0} cs.lastid++ if cs.last == nil { cs.first = c @@ -38,6 +52,7 @@ func (cs *clientConnections) add(con *http2.ClientConn) *clientConnection { } cs.last = c cs.count++ + cs.conns = append(cs.conns, c) return cs.last } @@ -65,11 +80,26 @@ func (cs *clientConnections) remove(c *clientConnection) { cs.count-- } +type clientConnectionSetter struct { + conn *clientConnection +} + type clientConnection struct { *http2.ClientConn - prev *clientConnection - next *clientConnection - id int + prev *clientConnection + next *clientConnection + id int + count int +} + +func (c *clientConnection) decrease() *clientConnection { + c.count-- + return c +} + +func (c *clientConnection) increase() *clientConnection { + c.count++ + return c } type connPair struct { @@ -83,12 +113,9 @@ type connPair struct { func (cp *connPair) next() *clientConnection { cp.mu.Lock() defer cp.mu.Unlock() - if cp.current == nil || cp.current.next == nil { - cp.current = cp.conns.first - } else { - cp.current = cp.current.next - } - return cp.current + con := cp.conns.low() + cp.current = con + return con } type connPool struct { @@ -105,7 +132,14 @@ func newConnPool(t *http2.Transport, f func(identifier id.ID)) *connPool { conns: make(map[string]*connPair), } } - +func (p *connPool) newRequest(method string, identifier id.ID, body io.Reader) (*http.Request, error) { + req, err := http.NewRequest(method, p.URL(identifier), body) + if err != nil { + return nil, err + } + req = req.WithContext(context.WithValue(req.Context(), reqPool, &clientConnectionSetter{})) + return req, err +} func (p *connPool) URL(identifier id.ID) string { return fmt.Sprint("https://", identifier) } @@ -116,6 +150,8 @@ func (p *connPool) GetClientConn(req *http.Request, addr string) (*http2.ClientC if cp, ok := p.conns[addr]; ok { conn := cp.next() + setter := req.Context().Value(reqPool) + setter.(*clientConnectionSetter).conn = conn cp.controller.logger.Log("level", 3, "client_conn", fmt.Sprintf("#%d", conn.id), "addr", addr) return conn.ClientConn, nil } @@ -197,6 +233,12 @@ func (p *connPool) DeleteConn(identifier id.ID) { } } +func (p *connPool) closeReqConn(req *http.Request) { + if conn := req.Context().Value(reqPool); conn != nil { + conn.(*clientConnectionSetter).conn.decrease() + } +} + func (p *connPool) Ping(identifier id.ID) (time.Duration, error) { p.mu.Lock() defer p.mu.Unlock() diff --git a/registered_client.go b/registered_client.go index 66fef99..c66e35c 100644 --- a/registered_client.go +++ b/registered_client.go @@ -57,7 +57,7 @@ func (p *RegisteredClientsFileSystemProvider) Get(clientID id.ID) (client *Regis var f io.ReadCloser if f, err = os.Open(pth); err != nil { if os.IsNotExist(err) { - return nil, nil + return nil, &ErrClientNotRegistered{clientID} } return nil, err } diff --git a/server.go b/server.go index fcf27a0..6794510 100644 --- a/server.go +++ b/server.go @@ -283,7 +283,7 @@ func (s *Server) handleClient(conn net.Conn, main bool) { inConnPool = true - req, err = http.NewRequest(http.MethodConnect, s.connPool.URL(controller.ID), body) + req, err = s.connPool.newRequest(http.MethodConnect, controller.ID, body) if err != nil { logger.Log( "level", 2, @@ -298,12 +298,14 @@ func (s *Server) handleClient(conn net.Conn, main bool) { } { - ctx, cancel := context.WithTimeout(context.Background(), DefaultTimeout) + ctx, cancel := context.WithTimeout(req.Context(), DefaultTimeout) defer cancel() req = req.WithContext(ctx) } resp, err = s.httpClient.Do(req) + defer s.connPool.closeReqConn(req) + if err != nil { logger.Log( "level", 2, @@ -419,7 +421,7 @@ func (s *Server) notifyError(serverError error, identifier id.ID) { return } - req, err := http.NewRequest(http.MethodConnect, s.connPool.URL(identifier), nil) + req, err := s.connPool.newRequest(http.MethodConnect, identifier, nil) if err != nil { s.logger.Log( "level", 2, @@ -661,8 +663,9 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe if err != nil { return err } + defer s.connPool.closeReqConn(req) - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(req.Context()) req = req.WithContext(ctx) done := make(chan struct{}) @@ -680,6 +683,7 @@ func (s *Server) proxyConn(identifier id.ID, conn net.Conn, msg *proto.ControlMe if err != nil { return fmt.Errorf("io error: %s", err) } + defer resp.Body.Close() transfer(conn, resp.Body, log.NewContext(s.logger).With( @@ -716,6 +720,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control if err != nil { return nil, fmt.Errorf("proxy request error: %s", err) } + defer s.connPool.closeReqConn(req) go func() { cw := &countWriter{pw, 0} @@ -765,7 +770,7 @@ func (s *Server) proxyHTTP(identifier id.ID, r *http.Request, msg *proto.Control // control message and data input stream, output data stream results from // response the created request. func (s *Server) connectRequest(identifier id.ID, msg *proto.ControlMessage, r io.Reader) (*http.Request, error) { - req, err := http.NewRequest(http.MethodPut, s.connPool.URL(identifier), r) + req, err := s.connPool.newRequest(http.MethodPut, identifier, r) if err != nil { return nil, fmt.Errorf("could not create request: %s", err) } From e09981d3c3c2d284821e79a23c52532b64c6fc41 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Mon, 15 Oct 2018 19:02:30 -0300 Subject: [PATCH 06/12] Add registered clients integration test --- integration_test.go | 113 +++++++++++++++++++++++++++++++++---- pool.go | 5 +- registered_clients_test.go | 14 +++++ server.go | 1 + 4 files changed, 121 insertions(+), 12 deletions(-) create mode 100644 registered_clients_test.go diff --git a/integration_test.go b/integration_test.go index 3a9a3da..38e8ba0 100644 --- a/integration_test.go +++ b/integration_test.go @@ -89,12 +89,13 @@ func makeEcho(t testing.TB) (http net.Listener, tcp net.Listener) { return } -func makeTunnelServer(t testing.TB) *tunnel.Server { +func makeTunnelServer(t testing.TB, clientsProvider tunnel.RegisteredClientsProvider) *tunnel.Server { s, err := tunnel.NewServer(&tunnel.ServerConfig{ Addr: ":0", - AutoSubscribe: true, + AutoSubscribe: clientsProvider == nil, TLSConfig: tlsConfig(), Logger: log.NewStdLogger(), + RegisteredClientsProvider: clientsProvider, }) if err != nil { t.Fatal(err) @@ -118,14 +119,14 @@ func makeTunnelClient(t testing.TB, serverAddr string, httpLocalAddr, httpAddr, tunnels := map[string]*proto.Tunnel{ proto.HTTP: { - Protocol: proto.HTTP, - Host: "localhost", - Auth: "user:password", + Protocol: proto.HTTP, + Host: "localhost", + Auth: "user:password", LocalAddr: httpLocalAddr.String(), }, proto.TCP: { - Protocol: proto.TCP, - LocalAddr: tcpLocalAddr.String(), + Protocol: proto.TCP, + LocalAddr: tcpLocalAddr.String(), }, } @@ -158,7 +159,7 @@ func TestIntegration(t *testing.T) { defer tcp.Close() // server - s := makeTunnelServer(t) + s := makeTunnelServer(t, nil) defer s.Stop() h := httptest.NewServer(s) defer h.Close() @@ -197,7 +198,7 @@ func TestIntegration(t *testing.T) { }() wg.Add(1) go func() { - testTCP(t, tcpLocalAddr, p, r) + testTCP(t, tcpLocalAddr, p, r, false) wg.Done() }() } @@ -205,6 +206,94 @@ func TestIntegration(t *testing.T) { wg.Wait() } +func makeTunnelRegisteredClient(t testing.TB, serverAddr string, tcpLocalAddr net.Addr) *tunnel.Client { + tcpProxy := tunnel.NewTCPProxy(tcpLocalAddr.String(), log.NewStdLogger()) + + tunnels := map[string]*proto.Tunnel{ + proto.TCP: { + Protocol: proto.TCP, + LocalAddr: tcpLocalAddr.String(), + }, + } + + c, err := tunnel.NewClient(&tunnel.ClientConfig{ + ServerAddr: serverAddr, + TLSClientConfig: tlsConfig(), + Tunnels: tunnels, + Registered: true, + Proxy: tunnel.Proxy(tunnel.ProxyFuncs{ + TCP: tcpProxy.Proxy, + }), + Logger: log.NewStdLogger(), + }) + if err != nil { + t.Fatal(err) + } + go func() { + if err := c.Start(); err != nil { + t.Log(err) + } + }() + + return c +} + +func TestRegisteredClientIntegration(t *testing.T) { + // local services + tcp, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + go echoTCP(tcp) + defer tcp.Close() + + tcpRemoteAddr := freeAddr() + numConnections := 3 + + // server + s := makeTunnelServer(t, ®isteredClientsProvider{&tunnel.RegisteredClientConfig{ + Connections: numConnections, + Tunnels: map[string]*proto.Tunnel{ + proto.TCP: { + Protocol: proto.TCP, + RemoteAddr: tcpRemoteAddr.String(), + }, + }, + }}) + defer s.Stop() + + // controller + c := makeTunnelRegisteredClient(t, s.Addr(), tcp.Addr()) + // FIXME: replace sleep with controller state change watch when ready + time.Sleep(500 * time.Millisecond) + defer c.Stop() + + payload := randPayload(payloadInitialSize, payloadLen) + table := []struct { + S []uint + }{ + {[]uint{200, 160, 120, 80, 40, 20}}, + {[]uint{40, 80, 120, 160, 200}}, + {[]uint{0, 0, 0, 0, 0, 200}}, + } + + var wg sync.WaitGroup + for _, test := range table { + for i, repeat := range test.S { + p := payload[i] + r := repeat + for i := 0; i <= numConnections; i++ { + wg.Add(1) + go func() { + testTCP(t, tcpRemoteAddr, p, r, true) + wg.Done() + }() + } + } + } + wg.Wait() +} + func testHTTP(t testing.TB, addr net.Addr, payload []byte, repeat uint) { url := fmt.Sprintf("http://localhost:%s/some/path", port(addr)) @@ -234,13 +323,17 @@ func testHTTP(t testing.TB, addr net.Addr, payload []byte, repeat uint) { } } -func testTCP(t testing.TB, addr net.Addr, payload []byte, repeat uint) { +func testTCP(t testing.TB, addr net.Addr, payload []byte, repeat uint, sleep bool) { conn, err := net.Dial("tcp", addr.String()) if err != nil { t.Fatal("Dial failed", err) } defer conn.Close() + if sleep { + time.Sleep(50 * time.Millisecond) + } + var buf = make([]byte, 10*1024*1024) var read, write int for repeat > 0 { diff --git a/pool.go b/pool.go index 0009569..9096424 100644 --- a/pool.go +++ b/pool.go @@ -150,8 +150,9 @@ func (p *connPool) GetClientConn(req *http.Request, addr string) (*http2.ClientC if cp, ok := p.conns[addr]; ok { conn := cp.next() - setter := req.Context().Value(reqPool) - setter.(*clientConnectionSetter).conn = conn + if setter := req.Context().Value(reqPool); setter != nil { + setter.(*clientConnectionSetter).conn = conn + } cp.controller.logger.Log("level", 3, "client_conn", fmt.Sprintf("#%d", conn.id), "addr", addr) return conn.ClientConn, nil } diff --git a/registered_clients_test.go b/registered_clients_test.go new file mode 100644 index 0000000..3da24b0 --- /dev/null +++ b/registered_clients_test.go @@ -0,0 +1,14 @@ +package tunnel_test + +import ( + "github.com/mmatczuk/go-http-tunnel" + "github.com/mmatczuk/go-http-tunnel/id" +) + +type registeredClientsProvider struct { + cfg *tunnel.RegisteredClientConfig +} + +func (p registeredClientsProvider) Get(clientID id.ID) (client *tunnel.RegisteredClientConfig, err error) { + return p.cfg, nil +} diff --git a/server.go b/server.go index 6794510..ffbc2a2 100644 --- a/server.go +++ b/server.go @@ -336,6 +336,7 @@ func (s *Server) handleClient(conn net.Conn, main bool) { } if cfg == nil { + cfg = &RegisteredClientConfig{} if err = json.NewDecoder(&io.LimitedReader{R: resp.Body, N: 126976}).Decode(cfg); err != nil { logger.Log( "level", 2, From e417dd36e8a16c07b498f0d6ceb71020480862f2 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Tue, 16 Oct 2018 08:25:31 -0300 Subject: [PATCH 07/12] Rename connection constant key of request context --- pool.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pool.go b/pool.go index 9096424..ce50fb7 100644 --- a/pool.go +++ b/pool.go @@ -19,7 +19,7 @@ import ( "github.com/mmatczuk/go-http-tunnel/id" ) -const reqPool = "req.pool" +const connKey = "conn" type clientConnections struct { first *clientConnection @@ -137,7 +137,7 @@ func (p *connPool) newRequest(method string, identifier id.ID, body io.Reader) ( if err != nil { return nil, err } - req = req.WithContext(context.WithValue(req.Context(), reqPool, &clientConnectionSetter{})) + req = req.WithContext(context.WithValue(req.Context(), connKey, &clientConnectionSetter{})) return req, err } func (p *connPool) URL(identifier id.ID) string { @@ -150,7 +150,7 @@ func (p *connPool) GetClientConn(req *http.Request, addr string) (*http2.ClientC if cp, ok := p.conns[addr]; ok { conn := cp.next() - if setter := req.Context().Value(reqPool); setter != nil { + if setter := req.Context().Value(connKey); setter != nil { setter.(*clientConnectionSetter).conn = conn } cp.controller.logger.Log("level", 3, "client_conn", fmt.Sprintf("#%d", conn.id), "addr", addr) @@ -235,7 +235,7 @@ func (p *connPool) DeleteConn(identifier id.ID) { } func (p *connPool) closeReqConn(req *http.Request) { - if conn := req.Context().Value(reqPool); conn != nil { + if conn := req.Context().Value(connKey); conn != nil { conn.(*clientConnectionSetter).conn.decrease() } } From 863bcb37849fe673b2522d49d508bb8a2f9801e7 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Tue, 16 Oct 2018 08:27:19 -0300 Subject: [PATCH 08/12] Add header comments --- cmd/tunnel/main.go | 4 ++++ cmd/tunneld/main.go | 4 ++++ registered_client.go | 4 ++++ registered_clients_test.go | 4 ++++ 4 files changed, 16 insertions(+) diff --git a/cmd/tunnel/main.go b/cmd/tunnel/main.go index 21b02a2..1f50381 100644 --- a/cmd/tunnel/main.go +++ b/cmd/tunnel/main.go @@ -1,3 +1,7 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + package main import "github.com/mmatczuk/go-http-tunnel/cli/tunnel" diff --git a/cmd/tunneld/main.go b/cmd/tunneld/main.go index 119a380..0e12e3c 100644 --- a/cmd/tunneld/main.go +++ b/cmd/tunneld/main.go @@ -1,3 +1,7 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + package main import "github.com/mmatczuk/go-http-tunnel/cli/tunneld" diff --git a/registered_client.go b/registered_client.go index c66e35c..f28cfd4 100644 --- a/registered_client.go +++ b/registered_client.go @@ -1,3 +1,7 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + package tunnel import ( diff --git a/registered_clients_test.go b/registered_clients_test.go index 3da24b0..7ca95dc 100644 --- a/registered_clients_test.go +++ b/registered_clients_test.go @@ -1,3 +1,7 @@ +// Copyright (C) 2017 Michał Matczuk +// Use of this source code is governed by an AGPL-style +// license that can be found in the LICENSE file. + package tunnel_test import ( From d69513b0ffa7f16ea8f17677e9b56798feb166cf Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Wed, 17 Oct 2018 08:47:44 -0300 Subject: [PATCH 09/12] Exposes "tunnel" and "tunneld" to be embedded --- cli/tunnel/config.go | 2 +- cli/tunnel/options.go | 56 ++++++++++++++++++++++++++++-------------- cli/tunnel/tunnel.go | 23 ++++++++++++----- cli/tunnel/version.go | 5 ++++ cli/tunneld/options.go | 6 +++-- cli/tunneld/tunneld.go | 4 +-- 6 files changed, 67 insertions(+), 29 deletions(-) diff --git a/cli/tunnel/config.go b/cli/tunnel/config.go index da13885..3176a1e 100644 --- a/cli/tunnel/config.go +++ b/cli/tunnel/config.go @@ -51,7 +51,7 @@ type ClientConfig struct { Tunnels map[string]*Tunnel `yaml:"tunnels"` } -func loadClientConfigFromFile(file string) (*ClientConfig, error) { +func LoadClientConfigFromFile(file string) (*ClientConfig, error) { buf, err := ioutil.ReadFile(file) if err != nil { return nil, fmt.Errorf("failed to read file %q: %s", file, err) diff --git a/cli/tunnel/options.go b/cli/tunnel/options.go index b70290d..844c4b2 100644 --- a/cli/tunnel/options.go +++ b/cli/tunnel/options.go @@ -67,24 +67,44 @@ func init() { } type options struct { - config string - logLevel int - version bool - command string - args []string + config string + logLevel int + version bool + command string + args []string } -func parseArgs(args ...string) (*options, error) { - config := flag.String("config", "tunnel.yaml", "Path to tunnel configuration file") - logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") - version := flag.Bool("version", false, "Prints tunnel version") - flag.CommandLine.Parse(args) +func (opt options) Command() string { + return opt.command +} + +func (opt options) LogLevel() int { + return opt.logLevel +} + +func (opt options) Args() []string { + return opt.args +} + +func ParseArgs(hasConfig bool, args ...string) (*options, error) { + var config *string + cli := flag.NewFlagSet(args[0], flag.ExitOnError) + args = args[1:] + if hasConfig { + config = cli.String("config", "tunnel.yaml", "Path to tunnel configuration file") + } else { + var s = "" + config = &s + } + logLevel := cli.Int("log-level", 1, "Level of messages to log, 0-3") + version := cli.Bool("version", false, "Prints tunnel version") + cli.Parse(args) opts := &options{ - config: *config, - logLevel: *logLevel, - version: *version, - command: flag.Arg(0), + config: *config, + logLevel: *logLevel, + version: *version, + command: cli.Arg(0), } if opts.version { @@ -93,20 +113,20 @@ func parseArgs(args ...string) (*options, error) { switch opts.command { case "": - flag.Usage() + cli.Usage() os.Exit(2) case "id", "list": - opts.args = flag.Args()[1:] + opts.args = cli.Args()[1:] if len(opts.args) > 0 { return nil, fmt.Errorf("list takes no arguments") } case "start": - opts.args = flag.Args()[1:] + opts.args = cli.Args()[1:] if len(opts.args) == 0 { return nil, fmt.Errorf("you must specify at least one tunnel to start") } case "start-all": - opts.args = flag.Args()[1:] + opts.args = cli.Args()[1:] if len(opts.args) > 0 { return nil, fmt.Errorf("start-all takes no arguments") } diff --git a/cli/tunnel/tunnel.go b/cli/tunnel/tunnel.go index 1d16821..f1706fb 100644 --- a/cli/tunnel/tunnel.go +++ b/cli/tunnel/tunnel.go @@ -28,25 +28,36 @@ func Main() { } func MainArgs(args ...string) { - opts, err := parseArgs(args...) + opts, err := ParseArgs(true, args...) if err != nil { fatal(err.Error()) } + MainOpts(opts) +} +func MainOpts(opts *options) { if opts.version { fmt.Println(version) return } - logger := log.NewFilterLogger(log.NewStdLogger(), opts.logLevel) - // read configuration file - config, err := loadClientConfigFromFile(opts.config) + config, err := LoadClientConfigFromFile(opts.config) if err != nil { fatal("configuration error: %s", err) } - switch opts.command { + MainConfigOptions(config, opts) +} + +func MainConfigOptions(config *ClientConfig, options *options) { + MainConfig(config, options.logLevel, options.command, options.args...) +} + +func MainConfig(config *ClientConfig, logLevel int, command string, args ...string) { + logger := log.NewFilterLogger(log.NewStdLogger(), logLevel) + + switch command { case "id": cert, err := tls.LoadX509KeyPair(config.TLSCrt, config.TLSKey) if err != nil { @@ -74,7 +85,7 @@ func MainArgs(args ...string) { return case "start": tunnels := make(map[string]*Tunnel) - for _, arg := range opts.args { + for _, arg := range args { t, ok := config.Tunnels[arg] if !ok { fatal("no such tunnel %q", arg) diff --git a/cli/tunnel/version.go b/cli/tunnel/version.go index 8f12ee6..28756d1 100644 --- a/cli/tunnel/version.go +++ b/cli/tunnel/version.go @@ -5,3 +5,8 @@ package tunnel var version = "snapshot" + +// Return the tunnel version +func GetVersion() string { + return version +} diff --git a/cli/tunneld/options.go b/cli/tunneld/options.go index 39daf64..bc7226a 100644 --- a/cli/tunneld/options.go +++ b/cli/tunneld/options.go @@ -49,7 +49,9 @@ type options struct { version bool } -func parseArgs(args ...string) *options { +func ParseArgs(args ...string) *options { + flag := flag.NewFlagSet(args[0], flag.ExitOnError) + args = args[1:] httpAddr := flag.String("httpAddr", ":80", "Public address for HTTP connections, empty string to disable") httpsAddr := flag.String("httpsAddr", ":443", "Public address listening for HTTPS connections, emptry string to disable") tunnelAddr := flag.String("tunnelAddr", ":5223", "Public address listening for tunnel client") @@ -60,7 +62,7 @@ func parseArgs(args ...string) *options { clientsDir := flag.String("registeredClientsDB", "", "Database directory of registered clients") logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") version := flag.Bool("version", false, "Prints tunneld version") - flag.CommandLine.Parse(args) + flag.Parse(args) return &options{ httpAddr: *httpAddr, diff --git a/cli/tunneld/tunneld.go b/cli/tunneld/tunneld.go index c577bb6..53f0e9d 100644 --- a/cli/tunneld/tunneld.go +++ b/cli/tunneld/tunneld.go @@ -21,11 +21,11 @@ import ( ) func Main() { - MainArgs(os.Args[1:]...) + MainArgs(os.Args...) } func MainArgs(args ...string) { - opts := parseArgs(args...) + opts := ParseArgs(args...) if opts.version { fmt.Println(version) From f6315e542592db053fb59ff37341dd1ffbcaff44 Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Wed, 17 Oct 2018 09:06:33 -0300 Subject: [PATCH 10/12] Update README with Exposes "tunnel" and "tunneld" to be embedded --- README.md | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/README.md b/README.md index 481ddbe..212296f 100644 --- a/README.md +++ b/README.md @@ -198,6 +198,56 @@ Client: $ tunnel start-all ``` +### Embeded + +#### tunnel +Example for embed `tunnel` in your code. + +```go +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunnel" + +func tunnelClient() { + options, err := tunnel.ParseArgs(false, "mycmd", "start-all") + if err != nil { + panic(err) + } + tunnel.MainConfigOptions(&tunnel.ClientConfig{ + ServerAddr: "domain.com:2000", + Tunnels: map[string]*tunnel.Tunnel{ + "main": { + Protocol: "tcp", + LocalAddr: "localhost:5000", + RemoteAddr: "domain.com:5000", + }, + }, + }, options) +} + +func main() { + go tunnelClient() + + // ... my system code +} +``` + +#### tunneld + +Example for embed `tunneld` in your code. + +```go +package main + +import "github.com/mmatczuk/go-http-tunnel/cli/tunneld" + +func main() { + go tunneld.MainArgs("mycmd", "-log-level", "3", "-httpAddr", ":19000", "-httpsAddr", ":19001") + + // ... my system code +} +``` + ## How it works A client opens TLS connection to a server. The server accepts connections from known clients only. The client is recognized by its TLS certificate ID. The server is publicly available and proxies incoming connections to the client. Then the connection is further proxied in the client's network. From 48d780c09bc47fa7ef54e483073f94f445aff37e Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Thu, 1 Nov 2018 19:01:42 -0200 Subject: [PATCH 11/12] Add read config from stdin --- cli/tunnel/config.go | 18 +++++++++++++++--- cli/tunnel/options.go | 3 ++- cli/tunnel/tunnel.go | 2 +- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/cli/tunnel/config.go b/cli/tunnel/config.go index 3176a1e..45c0c77 100644 --- a/cli/tunnel/config.go +++ b/cli/tunnel/config.go @@ -7,6 +7,7 @@ package tunnel import ( "fmt" "io/ioutil" + "os" "path/filepath" "time" @@ -21,6 +22,7 @@ const ( DefaultBackoffMultiplier = 1.5 DefaultBackoffMaxInterval = 60 * time.Second DefaultBackoffMaxTime = 15 * time.Minute + ConfigFileSTDIN = "-" ) // BackoffConfig defines behavior of staggering reconnection retries. @@ -52,9 +54,19 @@ type ClientConfig struct { } func LoadClientConfigFromFile(file string) (*ClientConfig, error) { - buf, err := ioutil.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("failed to read file %q: %s", file, err) + var ( + buf []byte + err error + ) + if file == ConfigFileSTDIN { + if buf, err = ioutil.ReadAll(os.Stdin); err != nil { + return nil, fmt.Errorf("failed to read config from STDIN: ", err) + } + file = "STDIN" + } else { + if buf, err = ioutil.ReadFile(file); err != nil { + return nil, fmt.Errorf("failed to read file %q: %s", file, err) + } } c := ClientConfig{ diff --git a/cli/tunnel/options.go b/cli/tunnel/options.go index 844c4b2..46bde17 100644 --- a/cli/tunnel/options.go +++ b/cli/tunnel/options.go @@ -91,7 +91,8 @@ func ParseArgs(hasConfig bool, args ...string) (*options, error) { cli := flag.NewFlagSet(args[0], flag.ExitOnError) args = args[1:] if hasConfig { - config = cli.String("config", "tunnel.yaml", "Path to tunnel configuration file") + config = cli.String("config", "tunnel.yaml", + "Path to tunnel configuration file. Use HYPHEN (-) for read STDIN.") } else { var s = "" config = &s diff --git a/cli/tunnel/tunnel.go b/cli/tunnel/tunnel.go index f1706fb..99f4b97 100644 --- a/cli/tunnel/tunnel.go +++ b/cli/tunnel/tunnel.go @@ -24,7 +24,7 @@ import ( ) func Main() { - MainArgs(os.Args[1:]...) + MainArgs(os.Args...) } func MainArgs(args ...string) { From 883ecc8f468b6c0cbf89d265b958b9ed240739bd Mon Sep 17 00:00:00 2001 From: "Moises P. Sena" Date: Tue, 8 Jan 2019 18:01:51 -0200 Subject: [PATCH 12/12] Update Config, Client subscription and README --- README.md | 45 +++++++++++++++++++++++++++++++++++++----- cli/tunnel/config.go | 17 ++++++++++++++-- cli/tunnel/options.go | 10 +++++----- cli/tunneld/options.go | 2 +- cli/tunneld/tunneld.go | 34 +++++++++++++++---------------- 5 files changed, 77 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 212296f..106fdd4 100644 --- a/README.md +++ b/README.md @@ -155,15 +155,15 @@ Configuration options: Create registered clients database structure if not exists: ```bash -$ mkdir registered_clients +$ mkdir clients ``` On client, get ID: `$ tunnel id` -Registered **Client DIR** `registered_clients/CLIENT_ID`. +Registered **Client DIR** `clients/CLIENT_ID`. Configure client (file `config.yaml`). -Example `registered_clients/SS2KPSV-5KG2URM-WYUEDLY-FBDAD7A-MUUVTDX-TL7KL45-2PQEQAD-IN4LVAH/config.yaml`: +Example `clients/SS2KPSV-5KG2URM-WYUEDLY-FBDAD7A-MUUVTDX-TL7KL45-2PQEQAD-IN4LVAH/config.yaml`: ```yaml connections: 4 @@ -178,7 +178,7 @@ tunnels: Configuration options: * `disabled`: if set to `true` block this client. Or create empty file `disabled` on client dir. -Example: `$ touch registered_clients/CLIENT_ID/disabled`. +Example: `$ touch clients/CLIENT_ID/disabled`. * `connections`: number of connections in addition to the main connection. Default is `0`. * `tunnels / [name]` * `proto`: tunnel protocol, `http` or `tcp` @@ -190,7 +190,11 @@ Example: `$ touch registered_clients/CLIENT_ID/disabled`. Server: ```bash -$ tunneld -registeredClientsDB registered_clients +$ tunneld -clientsDir ./clients + +# or (using default clients dir, only if dir exists) + +$ tunneld ``` Client: @@ -198,6 +202,37 @@ Client: $ tunnel start-all ``` +#### With Supervisor + +Server example: + + [program:tunneld] + directory=/etc/tunneld + command=/usr/bin/tunneld + autostart=true + autorestart=true + stdout_logfile=/var/log/tunneld.log + stdout_logfile_maxbytes=5MB + stdout_logfile_backups=2 + redirect_stderr=true + user = root + environment = HOME="/root", USER="root" + +Client example: + + [program:tunnel] + directory=/etc/tunnel + command=/usr/bin/tunnel + autostart=true + autorestart=true + startsecs=10 + stdout_logfile=/var/log/tunnel.log + stdout_logfile_maxbytes=5MB + stdout_logfile_backups=2 + redirect_stderr=true + user = root + environment = HOME="/root", USER="root" + ### Embeded #### tunnel diff --git a/cli/tunnel/config.go b/cli/tunnel/config.go index 45c0c77..e710674 100644 --- a/cli/tunnel/config.go +++ b/cli/tunnel/config.go @@ -58,20 +58,25 @@ func LoadClientConfigFromFile(file string) (*ClientConfig, error) { buf []byte err error ) + + var configDir string + if file == ConfigFileSTDIN { if buf, err = ioutil.ReadAll(os.Stdin); err != nil { return nil, fmt.Errorf("failed to read config from STDIN: ", err) } file = "STDIN" + configDir = "." } else { + configDir = filepath.Dir(file) if buf, err = ioutil.ReadFile(file); err != nil { return nil, fmt.Errorf("failed to read file %q: %s", file, err) } } c := ClientConfig{ - TLSCrt: filepath.Join(filepath.Dir(file), "client.crt"), - TLSKey: filepath.Join(filepath.Dir(file), "client.key"), + TLSCrt: "client.crt", + TLSKey: "client.key", Backoff: BackoffConfig{ Interval: DefaultBackoffInterval, Multiplier: DefaultBackoffMultiplier, @@ -84,6 +89,14 @@ func LoadClientConfigFromFile(file string) (*ClientConfig, error) { return nil, fmt.Errorf("failed to parse file %q: %s", file, err) } + if filepath.Dir(c.TLSCrt) == "." { + c.TLSCrt = filepath.Join(configDir, c.TLSCrt) + } + + if filepath.Dir(c.TLSKey) == "." { + c.TLSKey = filepath.Join(configDir, c.TLSKey) + } + if c.ServerAddr == "" { return nil, fmt.Errorf("server_addr: missing") } diff --git a/cli/tunnel/options.go b/cli/tunnel/options.go index 46bde17..296d75d 100644 --- a/cli/tunnel/options.go +++ b/cli/tunnel/options.go @@ -86,7 +86,7 @@ func (opt options) Args() []string { return opt.args } -func ParseArgs(hasConfig bool, args ...string) (*options, error) { +func ParseArgs(hasConfig bool, args ...string) (opts *options, err error) { var config *string cli := flag.NewFlagSet(args[0], flag.ExitOnError) args = args[1:] @@ -99,9 +99,11 @@ func ParseArgs(hasConfig bool, args ...string) (*options, error) { } logLevel := cli.Int("log-level", 1, "Level of messages to log, 0-3") version := cli.Bool("version", false, "Prints tunnel version") - cli.Parse(args) + if err = cli.Parse(args); err != nil { + return nil, err + } - opts := &options{ + opts = &options{ config: *config, logLevel: *logLevel, version: *version, @@ -114,8 +116,6 @@ func ParseArgs(hasConfig bool, args ...string) (*options, error) { switch opts.command { case "": - cli.Usage() - os.Exit(2) case "id", "list": opts.args = cli.Args()[1:] if len(opts.args) > 0 { diff --git a/cli/tunneld/options.go b/cli/tunneld/options.go index bc7226a..353f7ec 100644 --- a/cli/tunneld/options.go +++ b/cli/tunneld/options.go @@ -59,7 +59,7 @@ func ParseArgs(args ...string) *options { tlsKey := flag.String("tlsKey", "server.key", "Path to a TLS key file") rootCA := flag.String("rootCA", "", "Path to the trusted certificate chian used for client certificate authentication, if empty any client certificate is accepted") clients := flag.String("clients", "", "Comma-separated list of tunnel client ids, if empty accept all clients") - clientsDir := flag.String("registeredClientsDB", "", "Database directory of registered clients") + clientsDir := flag.String("clientsDir", "clients", "Database directory of registered clients") logLevel := flag.Int("log-level", 1, "Level of messages to log, 0-3") version := flag.Bool("version", false, "Prints tunneld version") flag.Parse(args) diff --git a/cli/tunneld/tunneld.go b/cli/tunneld/tunneld.go index 53f0e9d..fc1daeb 100644 --- a/cli/tunneld/tunneld.go +++ b/cli/tunneld/tunneld.go @@ -41,13 +41,17 @@ func MainArgs(args ...string) { fatal("failed to configure tls: %s", err) } - autoSubscribe := opts.clients == "" && opts.clientsDir == "" - var clientsProvider tunnel.RegisteredClientsProvider if opts.clientsDir != "" { - clientsProvider = &tunnel.RegisteredClientsFileSystemProvider{opts.clientsDir} + if _, err := os.Stat(opts.clientsDir); err == nil { + clientsProvider = &tunnel.RegisteredClientsFileSystemProvider{opts.clientsDir} + } else { + opts.clientsDir = "" + } } + autoSubscribe := opts.clients == "" && opts.clientsDir == "" + // setup server server, err := tunnel.NewServer(&tunnel.ServerConfig{ Addr: opts.tunnelAddr, @@ -60,23 +64,17 @@ func MainArgs(args ...string) { fatal("failed to create server: %s", err) } - if !autoSubscribe { - if opts.clientsDir != "" { - if _, err := os.Stat(opts.clientsDir); err != nil { - fatal(err.Error()) + if !autoSubscribe && clientsProvider == nil { + for _, c := range strings.Split(opts.clients, ",") { + if c == "" { + fatal("empty client id") } - } else { - for _, c := range strings.Split(opts.clients, ",") { - if c == "" { - fatal("empty client id") - } - identifier := id.ID{} - err := identifier.UnmarshalText([]byte(c)) - if err != nil { - fatal("invalid identifier %q: %s", c, err) - } - server.Subscribe(identifier) + identifier := id.ID{} + err := identifier.UnmarshalText([]byte(c)) + if err != nil { + fatal("invalid identifier %q: %s", c, err) } + server.Subscribe(identifier) } }