Skip to content

Refactor logging to use stderr, auth key to stdout #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions cmd/wush/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
"os"
"strings"
"sync"
"time"

"github.com/mattn/go-isatty"
"github.com/prometheus/client_golang/prometheus"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
Expand Down Expand Up @@ -51,11 +51,14 @@ func serveCmd() *serpent.Command {
logSink = inv.Stderr
}
logger := slog.New(slog.NewTextHandler(logSink, nil))
hlog := func(format string, args ...any) {
fmt.Fprintf(inv.Stderr, format+"\n", args...)
}
dm, err := tsserver.DERPMapTailscale(ctx)
if err != nil {
return err
}
r := overlay.NewReceiveOverlay(logger, dm)
r := overlay.NewReceiveOverlay(logger, hlog, dm)

switch overlayType {
case "derp":
Expand All @@ -76,9 +79,15 @@ func serveCmd() *serpent.Command {
return fmt.Errorf("unknown overlay type: %s", overlayType)
}

fmt.Println("Your auth key is:")
fmt.Println("\t>", cliui.Code(r.ClientAuth().AuthKey()))
fmt.Println("Use this key to authenticate other", cliui.Code("wush"), "commands to this instance.")
// Ensure we always print the auth key on stdout
if isatty.IsTerminal(os.Stdout.Fd()) {
hlog("Your auth key is:")
fmt.Println("\t>", cliui.Code(r.ClientAuth().AuthKey()))
hlog("Use this key to authenticate other " + cliui.Code("wush") + " commands to this instance.")
} else {
fmt.Println(cliui.Code(r.ClientAuth().AuthKey()))
hlog("The auth key has been printed to stdout")
}

s, err := tsserver.NewServer(ctx, logger, r)
if err != nil {
Expand All @@ -95,7 +104,7 @@ func serveCmd() *serpent.Command {
ts.Up(ctx)
fs := afero.NewOsFs()

fmt.Println(cliui.Timestamp(time.Now()), "WireGuard is ready")
hlog("WireGuard is ready")

closers := []io.Closer{}

Expand All @@ -117,15 +126,16 @@ func serveCmd() *serpent.Command {
}
closers = append(closers, sshListener)

fmt.Println(cliui.Timestamp(time.Now()), "SSH server "+pretty.Sprint(cliui.DefaultStyles.Enabled, "enabled"))
// TODO: replace these logs with all of the options in the beginning.
hlog("SSH server " + pretty.Sprint(cliui.DefaultStyles.Enabled, "enabled"))
go func() {
err := sshSrv.Serve(sshListener)
if err != nil {
fmt.Println(cliui.Timestamp(time.Now()), "SSH server exited: "+err.Error())
hlog("SSH server exited: " + err.Error())
}
}()
} else {
fmt.Println(cliui.Timestamp(time.Now()), "SSH server "+pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
hlog("SSH server " + pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
}

if xslices.Contains(enabled, "cp") && !xslices.Contains(disabled, "cp") {
Expand All @@ -135,23 +145,23 @@ func serveCmd() *serpent.Command {
}
closers = append([]io.Closer{cpListener}, closers...)

fmt.Println(cliui.Timestamp(time.Now()), "File transfer server "+pretty.Sprint(cliui.DefaultStyles.Enabled, "enabled"))
hlog("File transfer server " + pretty.Sprint(cliui.DefaultStyles.Enabled, "enabled"))
go func() {
err := http.Serve(cpListener, http.HandlerFunc(cpHandler))
if err != nil {
fmt.Println(cliui.Timestamp(time.Now()), "File transfer server exited: "+err.Error())
hlog("File transfer server exited: " + err.Error())
}
}()
} else {
fmt.Println(cliui.Timestamp(time.Now()), "File transfer server "+pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
hlog("File transfer server " + pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
}

if xslices.Contains(enabled, "port-forward") && !xslices.Contains(disabled, "port-forward") {
ts.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
return func(src net.Conn) {
dst, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", dst.Port()))
if err != nil {
fmt.Println("failed to dial forwarded connection:", err.Error())
hlog(pretty.Sprint(cliui.DefaultStyles.Warn, "Failed to dial forwarded connection:", err.Error()))
src.Close()
return
}
Expand All @@ -160,7 +170,7 @@ func serveCmd() *serpent.Command {
}, true
})
} else {
fmt.Println(cliui.Timestamp(time.Now()), "Port-forward server "+pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
hlog("Port-forward server " + pretty.Sprint(cliui.DefaultStyles.Disabled, "disabled"))
}

ctx, ctxCancel := inv.SignalNotifyContext(ctx, os.Interrupt)
Expand Down
52 changes: 27 additions & 25 deletions overlay/receive.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,28 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/types/key"

"github.com/coder/pretty"
"github.com/coder/wush/cliui"
)

func NewReceiveOverlay(logger *slog.Logger, dm *tailcfg.DERPMap) *Receive {
type Logf func(format string, args ...any)

func NewReceiveOverlay(logger *slog.Logger, hlog Logf, dm *tailcfg.DERPMap) *Receive {
return &Receive{
Logger: logger,
DerpMap: dm,
SelfPriv: key.NewNode(),
PeerPriv: key.NewNode(),
in: make(chan *tailcfg.Node, 8),
out: make(chan *tailcfg.Node, 8),
Logger: logger,
HumanLogf: hlog,
DerpMap: dm,
SelfPriv: key.NewNode(),
PeerPriv: key.NewNode(),
in: make(chan *tailcfg.Node, 8),
out: make(chan *tailcfg.Node, 8),
}
}

type Receive struct {
Logger *slog.Logger
DerpMap *tailcfg.DERPMap
Logger *slog.Logger
HumanLogf Logf
DerpMap *tailcfg.DERPMap
// SelfPriv is the private key that peers will encrypt overlay messages to.
// The public key of this is sent in the auth key.
SelfPriv key.NodePrivate
Expand Down Expand Up @@ -83,10 +88,10 @@ func (r *Receive) PickDERPHome(ctx context.Context) error {
}

if report.PreferredDERP == 0 {
fmt.Println("Failed to determine overlay DERP region, defaulting to", cliui.Code("NYC"), ".")
r.HumanLogf("Failed to determine overlay DERP region, defaulting to %s.", cliui.Code("NYC"))
r.derpRegionID = 1
} else {
fmt.Println("Picked DERP region", cliui.Code(r.DerpMap.Regions[report.PreferredDERP].RegionName), "as overlay home")
r.HumanLogf("Picked DERP region %s as overlay home", cliui.Code(r.DerpMap.Regions[report.PreferredDERP].RegionName))
r.derpRegionID = uint16(report.PreferredDERP)
}

Expand Down Expand Up @@ -139,7 +144,7 @@ func (r *Receive) ListenOverlaySTUN(ctx context.Context) (<-chan struct{}, error
case <-restun.C:
_, err = conn.WriteToUDP(m.Raw, srvAddr)
if err != nil {
fmt.Println(cliui.Timestamp(time.Now()), "Failed to write STUN request on overlay:", err)
r.HumanLogf("%s Failed to write STUN request on overlay: %s", cliui.Timestamp(time.Now()), err)
}
restun.Reset(30 * time.Second)
}
Expand Down Expand Up @@ -169,7 +174,7 @@ func (r *Receive) ListenOverlaySTUN(ctx context.Context) (<-chan struct{}, error
peers.Range(func(_ key.NodePublic, addr netip.AddrPort) bool {
_, err := conn.WriteToUDPAddrPort(sealed, addr)
if err != nil {
fmt.Println("send response over udp:", err)
r.HumanLogf("%s Failed to send updated node over udp: %s", cliui.Timestamp(time.Now()), err)
return false
}
return true
Expand Down Expand Up @@ -216,14 +221,11 @@ func (r *Receive) ListenOverlaySTUN(ctx context.Context) (<-chan struct{}, error

// our first STUN response
if !r.stunIP.IsValid() {
fmt.Println(cliui.Timestamp(time.Now()), "STUN address is", cliui.Code(stunAddrPort.String()))
r.HumanLogf("STUN address is %s", cliui.Code(stunAddrPort.String()))
}

if r.stunIP.IsValid() && r.stunIP.Compare(stunAddrPort) != 0 {
r.Logger.Warn("STUN address changed, this may cause issues",
"old_ip", r.stunIP.String(),
"new_ip", stunAddrPort.String(),
)
r.HumanLogf(pretty.Sprintf(cliui.DefaultStyles.Warn, "STUN address changed, this may cause issues; %s->%s", r.stunIP.String(), stunAddrPort.String()))
}
r.stunIP = stunAddrPort
closeIPChanOnce.Do(func() {
Expand All @@ -234,7 +236,7 @@ func (r *Receive) ListenOverlaySTUN(ctx context.Context) (<-chan struct{}, error

res, key, err := r.handleNextMessage(buf, "STUN")
if err != nil {
fmt.Println(cliui.Timestamp(time.Now()), "Failed to handle overlay message:", err.Error())
r.HumanLogf("Failed to handle overlay message: %s", err.Error())
continue
}

Expand All @@ -243,7 +245,7 @@ func (r *Receive) ListenOverlaySTUN(ctx context.Context) (<-chan struct{}, error
if res != nil {
_, err = conn.WriteToUDPAddrPort(res, addr)
if err != nil {
fmt.Println(cliui.Timestamp(time.Now()), "Failed to send overlay response over STUN:", err.Error())
r.HumanLogf("Failed to send overlay response over STUN: %s", err.Error())
return
}
}
Expand Down Expand Up @@ -285,7 +287,7 @@ func (r *Receive) ListenOverlayDERP(ctx context.Context) error {
peers.Range(func(_, derpKey key.NodePublic) bool {
err = c.Send(derpKey, sealed)
if err != nil {
fmt.Println("send response over derp:", err)
r.HumanLogf("Send updated node over DERP: %s", err)
return false
}
return true
Expand All @@ -304,7 +306,7 @@ func (r *Receive) ListenOverlayDERP(ctx context.Context) error {
case derp.ReceivedPacket:
res, key, err := r.handleNextMessage(msg.Data, "DERP")
if err != nil {
fmt.Println(cliui.Timestamp(time.Now()), "Failed to handle overlay message:", err.Error())
r.HumanLogf("Failed to handle overlay message: %s", err.Error())
continue
}

Expand All @@ -313,7 +315,7 @@ func (r *Receive) ListenOverlayDERP(ctx context.Context) error {
if res != nil {
err = c.Send(msg.Source, res)
if err != nil {
fmt.Println(cliui.Timestamp(time.Now()), "Failed to send overlay response over derp:", err.Error())
r.HumanLogf("Failed to send overlay response over derp: %s", err.Error())
return err
}
}
Expand Down Expand Up @@ -350,9 +352,9 @@ func (r *Receive) handleNextMessage(msg []byte, system string) (resRaw []byte, n
if h := ovMsg.HostInfo.Hostname; h != "" {
hostname = h
}
fmt.Println(cliui.Timestamp(time.Now()), "Received connection request over", system, "from", cliui.Keyword(fmt.Sprintf("%s@%s", username, hostname)))
r.HumanLogf("%s Received connection request over %s from %s", cliui.Timestamp(time.Now()), system, cliui.Keyword(fmt.Sprintf("%s@%s", username, hostname)))
case messageTypeNodeUpdate:
fmt.Println(cliui.Timestamp(time.Now()), "Received updated node from", cliui.Code(ovMsg.Node.Key.String()))
r.HumanLogf("%s Received updated node from %s", cliui.Timestamp(time.Now()), cliui.Code(ovMsg.Node.Key.String()))
r.in <- &ovMsg.Node
res.Typ = messageTypeNodeUpdate
if lastNode := r.lastNode.Load(); lastNode != nil {
Expand Down