Skip to content

Commit ff5f267

Browse files
oneirocosmsawka
andauthored
WSL Updates for New Architecture (#1756)
This adapts most of the WSL code to follow the new architecture that ssh uses. --------- Co-authored-by: sawka <[email protected]>
1 parent b7dca41 commit ff5f267

File tree

19 files changed

+1256
-994
lines changed

19 files changed

+1256
-994
lines changed

cmd/server/main-server.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
3636
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshserver"
3737
"github.com/wavetermdev/waveterm/pkg/wshutil"
38-
"github.com/wavetermdev/waveterm/pkg/wsl"
38+
"github.com/wavetermdev/waveterm/pkg/wslconn"
3939
"github.com/wavetermdev/waveterm/pkg/wstore"
4040
)
4141

@@ -145,7 +145,7 @@ func beforeSendActivityUpdate(ctx context.Context) {
145145
activity.Blocks, _ = wstore.DBGetBlockViewCounts(ctx)
146146
activity.NumWindows, _ = wstore.DBGetCount[*waveobj.Window](ctx)
147147
activity.NumSSHConn = conncontroller.GetNumSSHHasConnected()
148-
activity.NumWSLConn = wsl.GetNumWSLHasConnected()
148+
activity.NumWSLConn = wslconn.GetNumWSLHasConnected()
149149
activity.NumWSNamed, activity.NumWS, _ = wstore.DBGetWSCounts(ctx)
150150
err := telemetry.UpdateActivity(ctx, activity)
151151
if err != nil {

pkg/blockcontroller/blockcontroller.go

+8-4
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import (
3333
"github.com/wavetermdev/waveterm/pkg/wshrpc"
3434
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
3535
"github.com/wavetermdev/waveterm/pkg/wshutil"
36-
"github.com/wavetermdev/waveterm/pkg/wsl"
36+
"github.com/wavetermdev/waveterm/pkg/wslconn"
3737
"github.com/wavetermdev/waveterm/pkg/wstore"
3838
)
3939

@@ -369,18 +369,22 @@ func (bc *BlockController) setupAndStartShellProcess(logCtx context.Context, rc
369369
credentialCtx, cancelFunc := context.WithTimeout(context.Background(), 60*time.Second)
370370
defer cancelFunc()
371371

372-
wslConn := wsl.GetWslConn(credentialCtx, wslName, false)
372+
wslConn := wslconn.GetWslConn(credentialCtx, wslName, false)
373373
connStatus := wslConn.DeriveConnStatus()
374374
if connStatus.Status != conncontroller.Status_Connected {
375375
return nil, fmt.Errorf("not connected, cannot start shellproc")
376376
}
377377

378378
// create jwt
379379
if !blockMeta.GetBool(waveobj.MetaKey_CmdNoWsh, false) {
380-
jwtStr, err := wshutil.MakeClientJWTToken(wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}, wslConn.GetDomainSocketName())
380+
sockName := wslConn.GetDomainSocketName()
381+
rpcContext := wshrpc.RpcContext{TabId: bc.TabId, BlockId: bc.BlockId, Conn: wslConn.GetName()}
382+
jwtStr, err := wshutil.MakeClientJWTToken(rpcContext, sockName)
381383
if err != nil {
382384
return nil, fmt.Errorf("error making jwt token: %w", err)
383385
}
386+
swapToken.SockName = sockName
387+
swapToken.RpcContext = &rpcContext
384388
swapToken.Env[wshutil.WaveJwtTokenVarName] = jwtStr
385389
cmdOpts.Env[wshutil.WaveJwtTokenVarName] = jwtStr
386390
}
@@ -747,7 +751,7 @@ func CheckConnStatus(blockId string) error {
747751
}
748752
if strings.HasPrefix(connName, "wsl://") {
749753
distroName := strings.TrimPrefix(connName, "wsl://")
750-
conn := wsl.GetWslConn(context.Background(), distroName, false)
754+
conn := wslconn.GetWslConn(context.Background(), distroName, false)
751755
connStatus := conn.DeriveConnStatus()
752756
if connStatus.Status != conncontroller.Status_Connected {
753757
return fmt.Errorf("not connected: %s", connStatus.Status)

pkg/genconn/wsl-impl.go

+15-11
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
1-
//go:build windows
2-
31
// Copyright 2025, Command Line Inc.
42
// SPDX-License-Identifier: Apache-2.0
53

64
package genconn
75

86
import (
7+
"context"
98
"fmt"
109
"io"
1110
"sync"
1211

13-
"github.com/ubuntu/gowsl"
12+
"github.com/wavetermdev/waveterm/pkg/wsl"
1413
)
1514

1615
var _ ShellClient = (*WSLShellClient)(nil)
1716

1817
type WSLShellClient struct {
19-
distro *gowsl.Distro
18+
distro *wsl.Distro
2019
}
2120

22-
func MakeWSLShellClient(distro *gowsl.Distro) *WSLShellClient {
21+
func MakeWSLShellClient(distro *wsl.Distro) *WSLShellClient {
2322
return &WSLShellClient{distro: distro}
2423
}
2524

@@ -28,8 +27,8 @@ func (c *WSLShellClient) MakeProcessController(cmdSpec CommandSpec) (ShellProces
2827
}
2928

3029
type WSLProcessController struct {
31-
distro *gowsl.Distro
32-
cmd *gowsl.Cmd
30+
distro *wsl.Distro
31+
cmd *wsl.WslCmd
3332
lock *sync.Mutex
3433
once *sync.Once
3534
stdinPiped bool
@@ -40,13 +39,13 @@ type WSLProcessController struct {
4039
cmdSpec CommandSpec
4140
}
4241

43-
func MakeWSLProcessController(distro *gowsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
42+
func MakeWSLProcessController(distro *wsl.Distro, cmdSpec CommandSpec) (*WSLProcessController, error) {
4443
fullCmd, err := BuildShellCommand(cmdSpec)
4544
if err != nil {
4645
return nil, fmt.Errorf("failed to build shell command: %w", err)
4746
}
4847

49-
cmd := distro.Command(nil, fullCmd)
48+
cmd := distro.WslCommand(context.Background(), fullCmd)
5049
if cmd == nil {
5150
return nil, fmt.Errorf("failed to create WSL command")
5251
}
@@ -87,9 +86,14 @@ func (w *WSLProcessController) Kill() {
8786
w.lock.Lock()
8887
defer w.lock.Unlock()
8988

90-
if w.cmd != nil && w.cmd.Process != nil {
91-
w.cmd.Process.Kill()
89+
if w.cmd == nil {
90+
return
91+
}
92+
process := w.cmd.GetProcess()
93+
if process == nil {
94+
return
9295
}
96+
process.Kill()
9397
}
9498

9599
func (w *WSLProcessController) StdinPipe() (io.WriteCloser, error) {

pkg/remote/conncontroller/conncontroller.go

+11-12
Original file line numberDiff line numberDiff line change
@@ -308,12 +308,13 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
308308
if err != nil {
309309
return false, "", "", fmt.Errorf("unable to start conn controller command: %w", err)
310310
}
311-
linesChan := wshutil.StreamToLinesChan(pipeRead)
312-
versionLine, err := wshutil.ReadLineWithTimeout(linesChan, 2*time.Second)
311+
linesChan := utilfn.StreamToLinesChan(pipeRead)
312+
versionLine, err := utilfn.ReadLineWithTimeout(linesChan, 2*time.Second)
313313
if err != nil {
314314
sshSession.Close()
315315
return false, "", "", fmt.Errorf("error reading wsh version: %w", err)
316316
}
317+
conn.Infof(ctx, "actual connnserverversion: %q\n", versionLine)
317318
conn.Infof(ctx, "got connserver version: %s\n", strings.TrimSpace(versionLine))
318319
isUpToDate, clientVersion, osArchStr, err := IsWshVersionUpToDate(ctx, versionLine)
319320
if err != nil {
@@ -326,11 +327,10 @@ func (conn *SSHConn) StartConnServer(ctx context.Context, afterUpdate bool) (boo
326327
}
327328
conn.Infof(ctx, "connserver up-to-date: %v\n", isUpToDate)
328329
if !isUpToDate {
329-
330330
sshSession.Close()
331331
return true, clientVersion, osArchStr, nil
332332
}
333-
jwtLine, err := wshutil.ReadLineWithTimeout(linesChan, 3*time.Second)
333+
jwtLine, err := utilfn.ReadLineWithTimeout(linesChan, 3*time.Second)
334334
if err != nil {
335335
sshSession.Close()
336336
return false, clientVersion, "", fmt.Errorf("error reading jwt status line: %w", err)
@@ -401,12 +401,6 @@ type WshInstallOpts struct {
401401
NoUserPrompt bool
402402
}
403403

404-
type WshInstallSkipError struct{}
405-
406-
func (wise *WshInstallSkipError) Error() string {
407-
return "skipping wsh installation"
408-
}
409-
410404
var queryTextTemplate = strings.TrimSpace(`
411405
Wave requires Wave Shell Extensions to be
412406
installed on %q
@@ -555,7 +549,7 @@ func (conn *SSHConn) Connect(ctx context.Context, connFlags *wshrpc.ConnKeywords
555549
}
556550
})
557551
if !connectAllowed {
558-
conn.Infof(ctx, "cannot connect to when status is %q\n", conn.GetStatus())
552+
conn.Infof(ctx, "cannot connect to %q when status is %q\n", conn.GetName(), conn.GetStatus())
559553
return fmt.Errorf("cannot connect to %q when status is %q", conn.GetName(), conn.GetStatus())
560554
}
561555
conn.Infof(ctx, "trying to connect to %q...\n", conn.GetName())
@@ -754,7 +748,12 @@ func (conn *SSHConn) connectInternal(ctx context.Context, connFlags *wshrpc.Conn
754748
conn.WithLock(func() {
755749
conn.Client = client
756750
})
757-
go conn.waitForDisconnect()
751+
go func() {
752+
defer func() {
753+
panichandler.PanicHandler("conncontroller:waitForDisconnect", recover())
754+
}()
755+
conn.waitForDisconnect()
756+
}()
758757
fmtAddr := knownhosts.Normalize(fmt.Sprintf("%s@%s", client.User(), client.RemoteAddr().String()))
759758
conn.Infof(ctx, "normalized knownhosts address: %s\n", fmtAddr)
760759
clientDisplayName := fmt.Sprintf("%s (%s)", conn.GetName(), fmtAddr)

pkg/service/clientservice/clientservice.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"github.com/wavetermdev/waveterm/pkg/wconfig"
1616
"github.com/wavetermdev/waveterm/pkg/wcore"
1717
"github.com/wavetermdev/waveterm/pkg/wshrpc"
18-
"github.com/wavetermdev/waveterm/pkg/wsl"
18+
"github.com/wavetermdev/waveterm/pkg/wslconn"
1919
"github.com/wavetermdev/waveterm/pkg/wstore"
2020
)
2121

@@ -42,7 +42,7 @@ func (cs *ClientService) GetTab(tabId string) (*waveobj.Tab, error) {
4242

4343
func (cs *ClientService) GetAllConnStatus(ctx context.Context) ([]wshrpc.ConnStatus, error) {
4444
sshStatuses := conncontroller.GetAllConnStatus()
45-
wslStatuses := wsl.GetAllConnStatus()
45+
wslStatuses := wslconn.GetAllConnStatus()
4646
return append(sshStatuses, wslStatuses...), nil
4747
}
4848

pkg/shellexec/shellexec.go

+65-50
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"github.com/creack/pty"
2222
"github.com/wavetermdev/waveterm/pkg/blocklogger"
2323
"github.com/wavetermdev/waveterm/pkg/panichandler"
24-
"github.com/wavetermdev/waveterm/pkg/remote"
2524
"github.com/wavetermdev/waveterm/pkg/remote/conncontroller"
2625
"github.com/wavetermdev/waveterm/pkg/util/pamparse"
2726
"github.com/wavetermdev/waveterm/pkg/util/shellutil"
@@ -30,7 +29,7 @@ import (
3029
"github.com/wavetermdev/waveterm/pkg/wshrpc"
3130
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
3231
"github.com/wavetermdev/waveterm/pkg/wshutil"
33-
"github.com/wavetermdev/waveterm/pkg/wsl"
32+
"github.com/wavetermdev/waveterm/pkg/wslconn"
3433
)
3534

3635
const DefaultGracefulKillWait = 400 * time.Millisecond
@@ -151,92 +150,108 @@ func (pp *PipePty) WriteString(s string) (n int, err error) {
151150
return pp.Write([]byte(s))
152151
}
153152

154-
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wsl.WslConn) (*ShellProc, error) {
155-
utilCtx, cancelFn := context.WithTimeout(ctx, 2*time.Second)
156-
defer cancelFn()
153+
func StartWslShellProc(ctx context.Context, termSize waveobj.TermSize, cmdStr string, cmdOpts CommandOptsType, conn *wslconn.WslConn) (*ShellProc, error) {
157154
client := conn.GetClient()
158-
shellPath := cmdOpts.ShellPath
155+
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)")
156+
connRoute := wshutil.MakeConnectionRouteId(conn.GetName())
157+
rpcClient := wshclient.GetBareRpcClient()
158+
remoteInfo, err := wshclient.RemoteGetInfoCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
159+
if err != nil {
160+
return nil, fmt.Errorf("unable to obtain client info: %w", err)
161+
}
162+
log.Printf("client info collected: %+#v", remoteInfo)
163+
var shellPath string
164+
if cmdOpts.ShellPath != "" {
165+
conn.Infof(ctx, "using shell path from command opts: %s\n", cmdOpts.ShellPath)
166+
shellPath = cmdOpts.ShellPath
167+
}
168+
configShellPath := conn.GetConfigShellPath()
169+
if shellPath == "" && configShellPath != "" {
170+
conn.Infof(ctx, "using shell path from config (conn:shellpath): %s\n", configShellPath)
171+
shellPath = configShellPath
172+
}
173+
if shellPath == "" && remoteInfo.Shell != "" {
174+
conn.Infof(ctx, "using shell path detected on remote machine: %s\n", remoteInfo.Shell)
175+
shellPath = remoteInfo.Shell
176+
}
159177
if shellPath == "" {
160-
remoteShellPath, err := wsl.DetectShell(utilCtx, client)
161-
if err != nil {
162-
return nil, err
163-
}
164-
shellPath = remoteShellPath
178+
conn.Infof(ctx, "no shell path detected, using default (/bin/bash)\n")
179+
shellPath = "/bin/bash"
165180
}
166181
var shellOpts []string
167-
log.Printf("detected shell: %s", shellPath)
182+
var cmdCombined string
183+
log.Printf("detected shell %q for conn %q\n", shellPath, conn.GetName())
168184

169-
err := wsl.InstallClientRcFiles(utilCtx, client)
185+
err = wshclient.RemoteInstallRcFilesCommand(rpcClient, &wshrpc.RpcOpts{Route: connRoute, Timeout: 2000})
170186
if err != nil {
171187
log.Printf("error installing rc files: %v", err)
172188
return nil, err
173189
}
174-
175-
homeDir := wsl.GetHomeDir(utilCtx, client)
176-
shellOpts = append(shellOpts, "~", "-d", client.Name())
177-
178-
var subShellOpts []string
190+
shellOpts = append(shellOpts, cmdOpts.ShellOpts...)
191+
shellType := shellutil.GetShellTypeFromShellPath(shellPath)
192+
conn.Infof(ctx, "detected shell type: %s\n", shellType)
179193

180194
if cmdStr == "" {
181195
/* transform command in order to inject environment vars */
182-
if isBashShell(shellPath) {
183-
log.Printf("recognized as bash shell")
196+
if shellType == shellutil.ShellType_bash {
184197
// add --rcfile
185198
// cant set -l or -i with --rcfile
186-
subShellOpts = append(subShellOpts, "--rcfile", fmt.Sprintf(`%s/.waveterm/%s/.bashrc`, homeDir, shellutil.BashIntegrationDir))
187-
} else if isFishShell(shellPath) {
188-
carg := fmt.Sprintf(`"set -x PATH \"%s\"/.waveterm/%s $PATH"`, homeDir, shellutil.WaveHomeBinDir)
189-
subShellOpts = append(subShellOpts, "-C", carg)
190-
} else if wsl.IsPowershell(shellPath) {
199+
bashPath := fmt.Sprintf("~/.waveterm/%s/.bashrc", shellutil.BashIntegrationDir)
200+
shellOpts = append(shellOpts, "--rcfile", bashPath)
201+
} else if shellType == shellutil.ShellType_fish {
202+
if cmdOpts.Login {
203+
shellOpts = append(shellOpts, "-l")
204+
}
205+
// source the wave.fish file
206+
waveFishPath := fmt.Sprintf("~/.waveterm/%s/wave.fish", shellutil.FishIntegrationDir)
207+
carg := fmt.Sprintf(`"source %s"`, waveFishPath)
208+
shellOpts = append(shellOpts, "-C", carg)
209+
} else if shellType == shellutil.ShellType_pwsh {
210+
pwshPath := fmt.Sprintf("~/.waveterm/%s/wavepwsh.ps1", shellutil.PwshIntegrationDir)
191211
// powershell is weird about quoted path executables and requires an ampersand first
192212
shellPath = "& " + shellPath
193-
subShellOpts = append(subShellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", fmt.Sprintf("%s/.waveterm/%s/wavepwsh.ps1", homeDir, shellutil.PwshIntegrationDir))
213+
shellOpts = append(shellOpts, "-ExecutionPolicy", "Bypass", "-NoExit", "-File", pwshPath)
194214
} else {
195215
if cmdOpts.Login {
196-
subShellOpts = append(subShellOpts, "-l")
216+
shellOpts = append(shellOpts, "-l")
197217
}
198218
if cmdOpts.Interactive {
199-
subShellOpts = append(subShellOpts, "-i")
219+
shellOpts = append(shellOpts, "-i")
200220
}
201-
// can't set environment vars this way
202-
// will try to do later if possible
221+
// zdotdir setting moved to after session is created
203222
}
223+
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
204224
} else {
225+
// TODO check quoting of cmdStr
205226
shellPath = cmdStr
206-
if cmdOpts.Login {
207-
subShellOpts = append(subShellOpts, "-l")
208-
}
209-
if cmdOpts.Interactive {
210-
subShellOpts = append(subShellOpts, "-i")
211-
}
212-
subShellOpts = append(subShellOpts, "-c", cmdStr)
227+
shellOpts = append(shellOpts, "-c", cmdStr)
228+
cmdCombined = fmt.Sprintf("%s %s", shellPath, strings.Join(shellOpts, " "))
229+
}
230+
conn.Infof(ctx, "starting shell, using command: %s\n", cmdCombined)
231+
conn.Infof(ctx, "WSL-NEWSESSION (StartWslShellProc)\n")
232+
233+
if shellType == shellutil.ShellType_zsh {
234+
zshDir := fmt.Sprintf("~/.waveterm/%s", shellutil.ZshIntegrationDir)
235+
conn.Infof(ctx, "setting ZDOTDIR to %s\n", zshDir)
236+
cmdCombined = fmt.Sprintf(`ZDOTDIR=%s %s`, zshDir, cmdCombined)
213237
}
214238

215239
jwtToken, ok := cmdOpts.Env[wshutil.WaveJwtTokenVarName]
216240
if !ok {
217241
return nil, fmt.Errorf("no jwt token provided to connection")
218242
}
219-
if remote.IsPowershell(shellPath) {
220-
shellOpts = append(shellOpts, "--", fmt.Sprintf(`$env:%s=%s;`, wshutil.WaveJwtTokenVarName, jwtToken))
221-
} else {
222-
shellOpts = append(shellOpts, "--", fmt.Sprintf(`%s=%s`, wshutil.WaveJwtTokenVarName, jwtToken))
223-
}
243+
cmdCombined = fmt.Sprintf(`%s=%s %s`, wshutil.WaveJwtTokenVarName, jwtToken, cmdCombined)
224244

225-
if isZshShell(shellPath) {
226-
shellOpts = append(shellOpts, fmt.Sprintf(`ZDOTDIR=%s/.waveterm/%s`, homeDir, shellutil.ZshIntegrationDir))
227-
}
228-
shellOpts = append(shellOpts, shellPath)
229-
shellOpts = append(shellOpts, subShellOpts...)
230-
log.Printf("full cmd is: %s %s", "wsl.exe", strings.Join(shellOpts, " "))
231-
232-
ecmd := exec.Command("wsl.exe", shellOpts...)
245+
log.Printf("full combined command: %s", cmdCombined)
246+
ecmd := exec.Command("wsl.exe", "~", "-d", client.Name(), "--", "sh", "-c", cmdCombined)
233247
if termSize.Rows == 0 || termSize.Cols == 0 {
234248
termSize.Rows = shellutil.DefaultTermRows
235249
termSize.Cols = shellutil.DefaultTermCols
236250
}
237251
if termSize.Rows <= 0 || termSize.Cols <= 0 {
238252
return nil, fmt.Errorf("invalid term size: %v", termSize)
239253
}
254+
shellutil.AddTokenSwapEntry(cmdOpts.SwapToken)
240255
cmdPty, err := pty.StartWithSize(ecmd, &pty.Winsize{Rows: uint16(termSize.Rows), Cols: uint16(termSize.Cols)})
241256
if err != nil {
242257
return nil, err

0 commit comments

Comments
 (0)