Skip to content

Commit 0d073a8

Browse files
author
Patrick Bajao
committed
Merge branch 'id-implement-client-keep-alive' into 'main'
Implement ClientKeepAlive option See merge request gitlab-org/gitlab-shell!622
2 parents 26a092c + a16dcb3 commit 0d073a8

File tree

6 files changed

+97
-19
lines changed

6 files changed

+97
-19
lines changed

config.yml.example

+2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ sshd:
7676
web_listen: "localhost:9122"
7777
# Maximum number of concurrent sessions allowed on a single SSH connection. Defaults to 10.
7878
concurrent_sessions_limit: 10
79+
# Sets an interval after which server will send keepalive message to a client
80+
client_alive_interval: 15
7981
# The server waits for this time (in seconds) for the ongoing connections to complete before shutting down. Defaults to 10.
8082
grace_period: 10
8183
# The endpoint that returns 200 OK if the server is ready to receive incoming connections; otherwise, it returns 503 Service Unavailable. Defaults to "/start".

internal/config/config.go

+21-15
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ const (
2222
)
2323

2424
type ServerConfig struct {
25-
Listen string `yaml:"listen,omitempty"`
26-
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
27-
ProxyPolicy string `yaml:"proxy_policy,omitempty"`
28-
WebListen string `yaml:"web_listen,omitempty"`
29-
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
30-
GracePeriodSeconds uint64 `yaml:"grace_period"`
31-
ReadinessProbe string `yaml:"readiness_probe"`
32-
LivenessProbe string `yaml:"liveness_probe"`
33-
HostKeyFiles []string `yaml:"host_key_files,omitempty"`
25+
Listen string `yaml:"listen,omitempty"`
26+
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
27+
ProxyPolicy string `yaml:"proxy_policy,omitempty"`
28+
WebListen string `yaml:"web_listen,omitempty"`
29+
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
30+
ClientAliveIntervalSeconds int64 `yaml:"client_alive_interval,omitempty"`
31+
GracePeriodSeconds uint64 `yaml:"grace_period"`
32+
ReadinessProbe string `yaml:"readiness_probe"`
33+
LivenessProbe string `yaml:"liveness_probe"`
34+
HostKeyFiles []string `yaml:"host_key_files,omitempty"`
3435
}
3536

3637
type HttpSettingsConfig struct {
@@ -75,12 +76,13 @@ var (
7576
}
7677

7778
DefaultServerConfig = ServerConfig{
78-
Listen: "[::]:22",
79-
WebListen: "localhost:9122",
80-
ConcurrentSessionsLimit: 10,
81-
GracePeriodSeconds: 10,
82-
ReadinessProbe: "/start",
83-
LivenessProbe: "/health",
79+
Listen: "[::]:22",
80+
WebListen: "localhost:9122",
81+
ConcurrentSessionsLimit: 10,
82+
GracePeriodSeconds: 10,
83+
ClientAliveIntervalSeconds: 15,
84+
ReadinessProbe: "/start",
85+
LivenessProbe: "/health",
8486
HostKeyFiles: []string{
8587
"/run/secrets/ssh-hostkeys/ssh_host_rsa_key",
8688
"/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key",
@@ -89,6 +91,10 @@ var (
8991
}
9092
)
9193

94+
func (sc *ServerConfig) ClientAliveInterval() time.Duration {
95+
return time.Duration(sc.ClientAliveIntervalSeconds) * time.Second
96+
}
97+
9298
func (sc *ServerConfig) GracePeriod() time.Duration {
9399
return time.Duration(sc.GracePeriodSeconds) * time.Second
94100
}

internal/sshd/connection.go

+30-2
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,41 @@ import (
77
"golang.org/x/crypto/ssh"
88
"golang.org/x/sync/semaphore"
99

10+
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
1011
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
1112

1213
"gitlab.com/gitlab-org/labkit/log"
1314
)
1415

16+
const KeepAliveMsg = "[email protected]"
17+
1518
type connection struct {
19+
cfg *config.Config
1620
concurrentSessions *semaphore.Weighted
1721
remoteAddr string
22+
sconn *ssh.ServerConn
1823
}
1924

2025
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
2126

22-
func newConnection(maxSessions int64, remoteAddr string) *connection {
27+
func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
2328
return &connection{
24-
concurrentSessions: semaphore.NewWeighted(maxSessions),
29+
cfg: cfg,
30+
concurrentSessions: semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit),
2531
remoteAddr: remoteAddr,
32+
sconn: sconn,
2633
}
2734
}
2835

2936
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
3037
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
3138

39+
if c.cfg.Server.ClientAliveIntervalSeconds > 0 {
40+
ticker := time.NewTicker(c.cfg.Server.ClientAliveInterval())
41+
defer ticker.Stop()
42+
go c.sendKeepAliveMsg(ctx, ticker)
43+
}
44+
3245
for newChannel := range chans {
3346
ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
3447
if newChannel.ChannelType() != "session" {
@@ -68,3 +81,18 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
6881
}()
6982
}
7083
}
84+
85+
func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
86+
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
87+
88+
for {
89+
select {
90+
case <-ctx.Done():
91+
return
92+
case <-ticker.C:
93+
ctxlog.Debug("session: handleShell: send keepalive message to a client")
94+
95+
c.sconn.SendRequest(KeepAliveMsg, true, nil)
96+
}
97+
}
98+
}

internal/sshd/connection_test.go

+42-1
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@ package sshd
33
import (
44
"context"
55
"errors"
6+
"sync"
67
"testing"
8+
"time"
79

810
"github.com/stretchr/testify/require"
911
"golang.org/x/crypto/ssh"
12+
13+
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
1014
)
1115

1216
type rejectCall struct {
@@ -47,8 +51,32 @@ func (f *fakeNewChannel) ExtraData() []byte {
4751
return f.extraData
4852
}
4953

54+
type fakeConn struct {
55+
ssh.Conn
56+
57+
sentRequestName string
58+
mu sync.Mutex
59+
}
60+
61+
func (f *fakeConn) SentRequestName() string {
62+
f.mu.Lock()
63+
defer f.mu.Unlock()
64+
65+
return f.sentRequestName
66+
}
67+
68+
func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
69+
f.mu.Lock()
70+
defer f.mu.Unlock()
71+
72+
f.sentRequestName = name
73+
74+
return true, nil, nil
75+
}
76+
5077
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
51-
conn := newConnection(sessionsNum, "127.0.0.1:50000")
78+
cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum, ClientAliveIntervalSeconds: 1}}
79+
conn := newConnection(cfg, "127.0.0.1:50000", &ssh.ServerConn{&fakeConn{}, nil})
5280

5381
chans := make(chan ssh.NewChannel, 1)
5482
chans <- newChannel
@@ -145,3 +173,16 @@ func TestAcceptSessionFails(t *testing.T) {
145173

146174
require.False(t, channelHandled)
147175
}
176+
177+
func TestClientAliveInterval(t *testing.T) {
178+
f := &fakeConn{}
179+
180+
conn := newConnection(&config.Config{}, "127.0.0.1:50000", &ssh.ServerConn{f, nil})
181+
182+
ticker := time.NewTicker(time.Millisecond)
183+
defer ticker.Stop()
184+
185+
go conn.sendKeepAliveMsg(context.Background(), ticker)
186+
187+
require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond)
188+
}

internal/sshd/sshd.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
180180

181181
started := time.Now()
182182
var establishSessionDuration float64
183-
conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
183+
conn := newConnection(s.Config, remoteAddr, sconn)
184184
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
185185
establishSessionDuration = time.Since(started).Seconds()
186186
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)

internal/sshd/sshd_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
265265
cfg.User = user
266266
cfg.Server.Listen = serverUrl
267267
cfg.Server.ConcurrentSessionsLimit = 1
268+
cfg.Server.ClientAliveIntervalSeconds = 15
268269
cfg.Server.HostKeyFiles = []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}
269270

270271
s, err := NewServer(cfg)

0 commit comments

Comments
 (0)