
Commit 31e8563

Feature: disconnect agent from API on signal
If the agent receives a SIGUSR1 signal, it will disconnect from the API, but it will continue running checks. This makes it possible for a new version of the agent to connect to the API. When that happens, it's possible to kill the running agent (SIGTERM or otherwise). This should provide for a smoother upgrade process.

Signed-off-by: Marcelo E. Magallon <[email protected]>
1 parent b80ef5c commit 31e8563
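
As a rough, operator-side illustration of the upgrade flow the commit message describes (not part of this commit; the PID and the fixed wait are placeholders for whatever your process manager actually knows), the sequence could look like this in Go:

```go
package main

import (
	"log"
	"os"
	"syscall"
	"time"
)

func main() {
	// Hypothetical PID of the currently running agent; a real supervisor
	// would track this itself or read it from a pidfile.
	oldAgent, err := os.FindProcess(12345)
	if err != nil {
		log.Fatal(err)
	}

	// Ask the old agent to disconnect from the API; it keeps running checks.
	if err := oldAgent.Signal(syscall.SIGUSR1); err != nil {
		log.Fatal(err)
	}

	// ... start the new agent here and wait until it has connected ...
	time.Sleep(10 * time.Second) // placeholder for "new agent is connected"

	// Once the replacement is connected, terminate the old agent cleanly.
	if err := oldAgent.Signal(syscall.SIGTERM); err != nil {
		log.Fatal(err)
	}
}
```

In practice you would confirm from the new agent's logs or from the API that it has connected before sending SIGTERM to the old one.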

3 files changed: +220 -10 lines


README.md: +15 lines
```diff
@@ -84,3 +84,18 @@ Now you should have the agent reporting as private probe, and running checks (if
 
 #### Deploy it using Kubernetes
 See [examples/kubernetes](./examples/kubernetes) for the documentation and example yaml files
+
+Signals
+-------
+
+The agent traps the following signals:
+
+* SIGTERM: The agent tries to clean up and shut down in an orderly
+  manner.
+* SIGUSR1: The agent disconnects from the API but keeps running checks.
+  After 1 minute elapses, the agent will try to reconnect to the API and
+  keep trying until it succeeds or it's killed. One possible use case is
+  upgrading a running agent with a newer version: after SIGUSR1 is sent,
+  the agent disconnects, allowing another agent to connect in its place.
+  If the new agent fails to connect, the old agent will reconnect and
+  take it from there.
```
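
The new README section reads as a small decision table: SIGTERM means clean shutdown, SIGUSR1 means drop the API connection but keep running checks and retry later. A minimal, self-contained Go sketch of that behaviour (an illustration only, not the agent's code; the printed messages stand in for the real disconnect and cleanup paths):

```go
package main

import (
	"fmt"
	"os"
	"os/signal"
	"syscall"
)

func main() {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGUSR1)

	for sig := range sigCh {
		switch sig {
		case syscall.SIGTERM:
			fmt.Println("SIGTERM: cleaning up and shutting down")
			return
		case syscall.SIGUSR1:
			fmt.Println("SIGUSR1: disconnecting from the API; checks keep running")
			// The agent then waits about a minute before trying to
			// reconnect, and keeps retrying until it succeeds or is killed.
		}
	}
}
```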

internal/checks/checks.go: +117 -10 lines
```diff
@@ -5,8 +5,12 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"os"
+	"os/signal"
 	"strconv"
 	"sync"
+	"sync/atomic"
+	"syscall"
 	"time"
 
 	"github.com/grafana/synthetic-monitoring-agent/internal/feature"
@@ -24,8 +28,9 @@ import (
 )
 
 var (
-	errNotAuthorized    = errors.New("probe not authorized")
-	errTransportClosing = errors.New("transport closing")
+	errNotAuthorized     = errors.New("probe not authorized")
+	errTransportClosing  = errors.New("transport closing")
+	errProbeUnregistered = errors.New("probe no longer registered")
 )
 
 // Updater represents a probe along with the collection of scrapers
@@ -219,16 +224,29 @@ func (c *Updater) Run(ctx context.Context) error {
 				Msg("context cancelled, closing updater")
 			return nil
 
+		case errors.Is(err, errProbeUnregistered):
+			// Probe unregistered itself from the API, wait
+			// until attempting to reconnect again.
+			c.logger.Warn().
+				Str("connection_state", c.api.conn.GetState().String()).
+				Msg("unregistered probe in API, sleeping for 1 minute...")
+
+			if err := sleepCtx(ctx, 1*time.Minute); err != nil {
+				return err
+			}
+
 		default:
 			c.logger.Warn().
 				Err(err).
 				Str("connection_state", c.api.conn.GetState().String()).
 				Msg("handling check changes")
+
 			// TODO(mem): this might be a transient error (e.g. bad connection). We probably need to
 			// fine-tune GRPPC's backoff parameters. We might also need to keep count of the reconnects, and
 			// give up if they hit some threshold?
-			time.Sleep(2 * time.Second)
-			continue
+			if err := sleepCtx(ctx, 2*time.Second); err != nil {
+				return err
+			}
 		}
 	}
 }
@@ -240,7 +258,7 @@ func (c *Updater) loop(ctx context.Context) error {
 
 	grpcErrorHandler := func(action string, err error) error {
 		status, ok := status.FromError(err)
-		c.logger.Error().Err(err).Uint32("code", uint32(status.Code())).Msg(status.Message())
+		c.logger.Error().Err(err).Str("action", action).Uint32("code", uint32(status.Code())).Msg(status.Message())
 
 		switch {
 		case !ok:
@@ -297,18 +315,79 @@ func (c *Updater) loop(ctx context.Context) error {
 		"buildstamp": version.Buildstamp(),
 	}).Set(1)
 
-	cc, err := client.GetChanges(ctx, &sm.Void{})
-	if err != nil {
-		return grpcErrorHandler("requesting changes from synthetic-monitoring-api", err)
+	// Create a child context so that we can communicate to the
+	// signal handler that we are done.
+	sigCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+
+	// We get _another_ context from the signal handler that we can
+	// use tell the GRPC client that we need to break out. We have
+	// multiple ways of cancelling the context (another signal
+	// elsewhere in the system communicated through the parent
+	// context; cancelling the child context because we are
+	// returning from this function; cancelling the new context
+	// because the signal fired), so we need an additional way of
+	// telling them apart.
+	sigCtx, signalFired := installSignalHandler(sigCtx)
+
+	action := "requesting changes from synthetic-monitoring-api"
+	cc, err := client.GetChanges(sigCtx, &sm.Void{})
+	if err == nil {
+		action = "getting changes from synthetic-monitoring-api"
+		// processChanges uses the context in its first
+		// argument to create scrapers. This means that
+		// cancelling that context cancels all the running
+		// scrapers. That's why we are passing the _original_
+		// context, not sigCtx, so that scrapers are _not_
+		// stopped if the signal is trapped. We want scrapers to
+		// continue running in case the agent is _not_ killed.
+		err = c.processChanges(ctx, cc)
 	}
 
-	if err := c.processChanges(ctx, cc); err != nil {
-		return grpcErrorHandler("getting changes from synthetic-monitoring-api", err)
+	if err != nil {
+		if atomic.LoadInt32(signalFired) == 1 {
+			return errProbeUnregistered
+		}
+
+		return grpcErrorHandler(action, err)
 	}
 
 	return nil
 }
 
+// installSignalHandler installs a signal handler for SIGUSR1.
+//
+// The returned context's Done channel is closed if the signal is
+// delivered. To make it simpler to determine if the signal was
+// delivered, a value of 1 is written to the location pointed to by the
+// returned int32 pointer.
+//
+// If the provided context's Done channel is closed before the signal is
+// delivered, the signal handler is removed and the returned context's
+// Done channel is closed, too. It's the callers responsibility to
+// cancel the provided context if it's no longer interested in the
+// signal.
+func installSignalHandler(ctx context.Context) (context.Context, *int32) {
+	sigCtx, cancel := context.WithCancel(ctx)
+
+	fired := new(int32)
+
+	sigCh := make(chan os.Signal, 1)
+	signal.Notify(sigCh, syscall.SIGUSR1)
+
+	go func() {
+		select {
+		case <-sigCh:
+			atomic.StoreInt32(fired, 1)
+			cancel()
+		case <-ctx.Done():
+		}
+		signal.Stop(sigCh)
+	}()
+
+	return sigCtx, fired
+}
+
 func (c *Updater) processChanges(ctx context.Context, cc sm.Checks_GetChangesClient) error {
 	firstBatch := true
 
@@ -317,6 +396,9 @@ func (c *Updater) processChanges(ctx context.Context, cc sm.Checks_GetChangesClient) error {
 		case <-cc.Context().Done():
 			return nil
 
+		case <-ctx.Done():
+			return ctx.Err()
+
 		default:
 			switch msg, err := cc.Recv(); err {
 			case nil:
@@ -472,6 +554,10 @@ func (c *Updater) handleFirstBatch(ctx context.Context, changes *sm.Changes) {
 			continue
 		}
 
+		c.logger.Debug().
+			Int64("check_id", id).
+			Msg("stopping scraper during first batch handling")
+
 		checkType := scraper.CheckType().String()
 		scraper.Stop()
 
@@ -599,3 +685,24 @@ func (c *Updater) addAndStartScraperWithLock(ctx context.Context, check sm.Check
 
 	return nil
 }
+
+// sleepCtx is like time.Sleep, but it pays attention to the
+// cancellation of the provided context.
+func sleepCtx(ctx context.Context, d time.Duration) error {
+	var err error
+
+	timer := time.NewTimer(d)
+
+	select {
+	case <-ctx.Done():
+		err = ctx.Err()
+
+		if !timer.Stop() {
+			<-timer.C
+		}
+
+	case <-timer.C:
+	}
+
+	return err
+}
```
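
Taken together, installSignalHandler and sleepCtx give the updater its signal-aware retry behaviour. Below is a condensed sketch of how a caller composes the two helpers; runWithSignalAwareRetries and connectAndStream are hypothetical names, with connectAndStream standing in for client.GetChanges plus processChanges (the real logic lives in Updater.loop and Updater.Run above):

```go
package checks

import (
	"context"
	"sync/atomic"
	"time"
)

// runWithSignalAwareRetries is a hypothetical illustration, not part of this
// commit: it shows how installSignalHandler and sleepCtx are meant to be
// combined by a caller.
func runWithSignalAwareRetries(ctx context.Context, connectAndStream func(context.Context) error) error {
	for {
		// A child context lets us tear down the signal handler when we are
		// done with this iteration.
		sigCtx, cancel := context.WithCancel(ctx)
		sigCtx, fired := installSignalHandler(sigCtx)

		err := connectAndStream(sigCtx)
		cancel() // releases the signal handler goroutine

		switch {
		case err == nil:
			return nil
		case atomic.LoadInt32(fired) == 1:
			// SIGUSR1 fired: stay disconnected for a while, then reconnect.
			if err := sleepCtx(ctx, 1*time.Minute); err != nil {
				return err
			}
		case ctx.Err() != nil:
			return ctx.Err()
		default:
			// Transient error: back off briefly before reconnecting.
			if err := sleepCtx(ctx, 2*time.Second); err != nil {
				return err
			}
		}
	}
}
```

The separate fired flag matters because cancellation alone cannot distinguish a SIGUSR1-triggered disconnect from an ordinary context cancellation; that is why the loop checks the flag before choosing between the long and the short back-off.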

internal/checks/checks_test.go: +88 lines
```diff
@@ -1,7 +1,11 @@
 package checks
 
 import (
+	"context"
+	"sync/atomic"
+	"syscall"
 	"testing"
+	"time"
 
 	"github.com/grafana/synthetic-monitoring-agent/internal/feature"
 	"github.com/grafana/synthetic-monitoring-agent/internal/pusher"
@@ -53,3 +57,87 @@ func TestNewUpdater(t *testing.T) {
 		})
 	}
 }
+
+func TestInstallSignalHandler(t *testing.T) {
+	testcases := map[string]func(t *testing.T){
+		"signal": func(t *testing.T) {
+			// verify that the signal context is done after
+			// receiving the signal, and that the signal is
+			// correctly reported as having fired.
+
+			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+			defer cancel()
+			sigCtx, signalFired := installSignalHandler(ctx)
+			require.NotNil(t, sigCtx)
+			require.NotNil(t, signalFired)
+			require.NoError(t, syscall.Kill(syscall.Getpid(), syscall.SIGUSR1))
+
+			select {
+			case <-ctx.Done():
+				t.Fatal("context timeout expired")
+			case <-sigCtx.Done():
+				require.Equal(t, int32(1), atomic.LoadInt32(signalFired))
+			}
+		},
+
+		"no signal": func(t *testing.T) {
+			// verify that the signal context is done after
+			// the parrent context is done, and that the
+			// signal is correctly reported as not having
+			// fired.
+
+			ctx, cancel := context.WithCancel(context.Background())
+			sigCtx, signalFired := installSignalHandler(ctx)
+			require.NotNil(t, sigCtx)
+			require.NotNil(t, signalFired)
+
+			cancel()
+
+			timeout := 100 * time.Millisecond
+			timer := time.NewTimer(timeout)
+			defer timer.Stop()
+
+			select {
+			case <-timer.C:
+				t.Fatalf("signal context not cancelled after %s", timeout)
+			case <-sigCtx.Done():
+				require.Equal(t, int32(0), atomic.LoadInt32(signalFired))
+			}
+		},
+	}
+
+	for name, f := range testcases {
+		t.Run(name, f)
+	}
+}
+
+func TestSleepCtx(t *testing.T) {
+	var (
+		veryShort = 1 * time.Microsecond
+		long      = 10 * time.Second
+	)
+
+	// make sure errors are reported correctly
+
+	ctx := context.Background()
+	err := sleepCtx(ctx, veryShort)
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+	err = sleepCtx(ctx, long)
+	require.Error(t, err)
+	require.ErrorIs(t, err, context.Canceled)
+
+	ctx, cancel = context.WithTimeout(context.Background(), veryShort)
+	err = sleepCtx(ctx, long)
+	require.Error(t, err)
+	require.ErrorIs(t, err, context.DeadlineExceeded)
+	cancel()
+
+	ctx, cancel = context.WithTimeout(context.Background(), long)
+	cancel()
+	err = sleepCtx(ctx, long)
+	require.Error(t, err)
+	require.ErrorIs(t, err, context.Canceled)
+}
```
