Skip to content

Commit a16dcb3

Browse files
author
Igor Drozdov
committed
Implement ClientKeepAlive option
Git clients sometimes open a connection and leave it idling, like when compressing objects. Settings like timeout client in HAProxy might cause these idle connections to be terminated. Let's send the keepalive message in order to prevent a client from closing
1 parent 42cf058 commit a16dcb3

File tree

6 files changed

+97
-19
lines changed

6 files changed

+97
-19
lines changed

config.yml.example

+2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ sshd:
7676
web_listen: "localhost:9122"
7777
# Maximum number of concurrent sessions allowed on a single SSH connection. Defaults to 10.
7878
concurrent_sessions_limit: 10
79+
# Sets an interval after which server will send keepalive message to a client
80+
client_alive_interval: 15
7981
# The server waits for this time (in seconds) for the ongoing connections to complete before shutting down. Defaults to 10.
8082
grace_period: 10
8183
# The endpoint that returns 200 OK if the server is ready to receive incoming connections; otherwise, it returns 503 Service Unavailable. Defaults to "/start".

internal/config/config.go

+21-15
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ const (
2222
)
2323

2424
type ServerConfig struct {
25-
Listen string `yaml:"listen,omitempty"`
26-
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
27-
ProxyPolicy string `yaml:"proxy_policy,omitempty"`
28-
WebListen string `yaml:"web_listen,omitempty"`
29-
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
30-
GracePeriodSeconds uint64 `yaml:"grace_period"`
31-
ReadinessProbe string `yaml:"readiness_probe"`
32-
LivenessProbe string `yaml:"liveness_probe"`
33-
HostKeyFiles []string `yaml:"host_key_files,omitempty"`
25+
Listen string `yaml:"listen,omitempty"`
26+
ProxyProtocol bool `yaml:"proxy_protocol,omitempty"`
27+
ProxyPolicy string `yaml:"proxy_policy,omitempty"`
28+
WebListen string `yaml:"web_listen,omitempty"`
29+
ConcurrentSessionsLimit int64 `yaml:"concurrent_sessions_limit,omitempty"`
30+
ClientAliveIntervalSeconds int64 `yaml:"client_alive_interval,omitempty"`
31+
GracePeriodSeconds uint64 `yaml:"grace_period"`
32+
ReadinessProbe string `yaml:"readiness_probe"`
33+
LivenessProbe string `yaml:"liveness_probe"`
34+
HostKeyFiles []string `yaml:"host_key_files,omitempty"`
3435
}
3536

3637
type HttpSettingsConfig struct {
@@ -75,12 +76,13 @@ var (
7576
}
7677

7778
DefaultServerConfig = ServerConfig{
78-
Listen: "[::]:22",
79-
WebListen: "localhost:9122",
80-
ConcurrentSessionsLimit: 10,
81-
GracePeriodSeconds: 10,
82-
ReadinessProbe: "/start",
83-
LivenessProbe: "/health",
79+
Listen: "[::]:22",
80+
WebListen: "localhost:9122",
81+
ConcurrentSessionsLimit: 10,
82+
GracePeriodSeconds: 10,
83+
ClientAliveIntervalSeconds: 15,
84+
ReadinessProbe: "/start",
85+
LivenessProbe: "/health",
8486
HostKeyFiles: []string{
8587
"/run/secrets/ssh-hostkeys/ssh_host_rsa_key",
8688
"/run/secrets/ssh-hostkeys/ssh_host_ecdsa_key",
@@ -89,6 +91,10 @@ var (
8991
}
9092
)
9193

94+
func (sc *ServerConfig) ClientAliveInterval() time.Duration {
95+
return time.Duration(sc.ClientAliveIntervalSeconds) * time.Second
96+
}
97+
9298
func (sc *ServerConfig) GracePeriod() time.Duration {
9399
return time.Duration(sc.GracePeriodSeconds) * time.Second
94100
}

internal/sshd/connection.go

+30-2
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,41 @@ import (
77
"golang.org/x/crypto/ssh"
88
"golang.org/x/sync/semaphore"
99

10+
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
1011
"gitlab.com/gitlab-org/gitlab-shell/internal/metrics"
1112

1213
"gitlab.com/gitlab-org/labkit/log"
1314
)
1415

16+
const KeepAliveMsg = "[email protected]"
17+
1518
type connection struct {
19+
cfg *config.Config
1620
concurrentSessions *semaphore.Weighted
1721
remoteAddr string
22+
sconn *ssh.ServerConn
1823
}
1924

2025
type channelHandler func(context.Context, ssh.Channel, <-chan *ssh.Request)
2126

22-
func newConnection(maxSessions int64, remoteAddr string) *connection {
27+
func newConnection(cfg *config.Config, remoteAddr string, sconn *ssh.ServerConn) *connection {
2328
return &connection{
24-
concurrentSessions: semaphore.NewWeighted(maxSessions),
29+
cfg: cfg,
30+
concurrentSessions: semaphore.NewWeighted(cfg.Server.ConcurrentSessionsLimit),
2531
remoteAddr: remoteAddr,
32+
sconn: sconn,
2633
}
2734
}
2835

2936
func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, handler channelHandler) {
3037
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
3138

39+
if c.cfg.Server.ClientAliveIntervalSeconds > 0 {
40+
ticker := time.NewTicker(c.cfg.Server.ClientAliveInterval())
41+
defer ticker.Stop()
42+
go c.sendKeepAliveMsg(ctx, ticker)
43+
}
44+
3245
for newChannel := range chans {
3346
ctxlog.WithField("channel_type", newChannel.ChannelType()).Info("connection: handle: new channel requested")
3447
if newChannel.ChannelType() != "session" {
@@ -68,3 +81,18 @@ func (c *connection) handle(ctx context.Context, chans <-chan ssh.NewChannel, ha
6881
}()
6982
}
7083
}
84+
85+
func (c *connection) sendKeepAliveMsg(ctx context.Context, ticker *time.Ticker) {
86+
ctxlog := log.WithContextFields(ctx, log.Fields{"remote_addr": c.remoteAddr})
87+
88+
for {
89+
select {
90+
case <-ctx.Done():
91+
return
92+
case <-ticker.C:
93+
ctxlog.Debug("session: handleShell: send keepalive message to a client")
94+
95+
c.sconn.SendRequest(KeepAliveMsg, true, nil)
96+
}
97+
}
98+
}

internal/sshd/connection_test.go

+42-1
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@ package sshd
33
import (
44
"context"
55
"errors"
6+
"sync"
67
"testing"
8+
"time"
79

810
"github.com/stretchr/testify/require"
911
"golang.org/x/crypto/ssh"
12+
13+
"gitlab.com/gitlab-org/gitlab-shell/internal/config"
1014
)
1115

1216
type rejectCall struct {
@@ -47,8 +51,32 @@ func (f *fakeNewChannel) ExtraData() []byte {
4751
return f.extraData
4852
}
4953

54+
type fakeConn struct {
55+
ssh.Conn
56+
57+
sentRequestName string
58+
mu sync.Mutex
59+
}
60+
61+
func (f *fakeConn) SentRequestName() string {
62+
f.mu.Lock()
63+
defer f.mu.Unlock()
64+
65+
return f.sentRequestName
66+
}
67+
68+
func (f *fakeConn) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
69+
f.mu.Lock()
70+
defer f.mu.Unlock()
71+
72+
f.sentRequestName = name
73+
74+
return true, nil, nil
75+
}
76+
5077
func setup(sessionsNum int64, newChannel *fakeNewChannel) (*connection, chan ssh.NewChannel) {
51-
conn := newConnection(sessionsNum, "127.0.0.1:50000")
78+
cfg := &config.Config{Server: config.ServerConfig{ConcurrentSessionsLimit: sessionsNum, ClientAliveIntervalSeconds: 1}}
79+
conn := newConnection(cfg, "127.0.0.1:50000", &ssh.ServerConn{&fakeConn{}, nil})
5280

5381
chans := make(chan ssh.NewChannel, 1)
5482
chans <- newChannel
@@ -145,3 +173,16 @@ func TestAcceptSessionFails(t *testing.T) {
145173

146174
require.False(t, channelHandled)
147175
}
176+
177+
func TestClientAliveInterval(t *testing.T) {
178+
f := &fakeConn{}
179+
180+
conn := newConnection(&config.Config{}, "127.0.0.1:50000", &ssh.ServerConn{f, nil})
181+
182+
ticker := time.NewTicker(time.Millisecond)
183+
defer ticker.Stop()
184+
185+
go conn.sendKeepAliveMsg(context.Background(), ticker)
186+
187+
require.Eventually(t, func() bool { return KeepAliveMsg == f.SentRequestName() }, time.Second, time.Millisecond)
188+
}

internal/sshd/sshd.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func (s *Server) handleConn(ctx context.Context, nconn net.Conn) {
180180

181181
started := time.Now()
182182
var establishSessionDuration float64
183-
conn := newConnection(s.Config.Server.ConcurrentSessionsLimit, remoteAddr)
183+
conn := newConnection(s.Config, remoteAddr, sconn)
184184
conn.handle(ctx, chans, func(ctx context.Context, channel ssh.Channel, requests <-chan *ssh.Request) {
185185
establishSessionDuration = time.Since(started).Seconds()
186186
metrics.SshdSessionEstablishedDuration.Observe(establishSessionDuration)

internal/sshd/sshd_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ func setupServerWithConfig(t *testing.T, cfg *config.Config) *Server {
265265
cfg.User = user
266266
cfg.Server.Listen = serverUrl
267267
cfg.Server.ConcurrentSessionsLimit = 1
268+
cfg.Server.ClientAliveIntervalSeconds = 15
268269
cfg.Server.HostKeyFiles = []string{path.Join(testhelper.TestRoot, "certs/valid/server.key")}
269270

270271
s, err := NewServer(cfg)

0 commit comments

Comments
 (0)