Skip to content

Commit fbec7bc

Browse files
authored
Merge pull request #50 from arduino/fix_crashing_disc_handling
Fix panic when dealing with crashing discoveries
2 parents 6ae82f5 + cc38790 commit fbec7bc

File tree

3 files changed

+77
-17
lines changed

3 files changed

+77
-17
lines changed

Diff for: client.go

+12-17
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ func (disc *Client) jsonDecodeLoop(in io.Reader, outChan chan<- *discoveryMessag
125125
closeAndReportError := func(err error) {
126126
disc.statusMutex.Lock()
127127
disc.incomingMessagesError = err
128-
disc.statusMutex.Unlock()
129128
disc.stopSync()
130129
disc.killProcess()
130+
disc.statusMutex.Unlock()
131131
close(outChan)
132132
if err != nil {
133133
disc.logger.Errorf("Stopped decode loop: %v", err)
@@ -138,11 +138,7 @@ func (disc *Client) jsonDecodeLoop(in io.Reader, outChan chan<- *discoveryMessag
138138

139139
for {
140140
var msg discoveryMessage
141-
if err := decoder.Decode(&msg); errors.Is(err, io.EOF) {
142-
// This is fine :flames: we exit gracefully
143-
closeAndReportError(nil)
144-
return
145-
} else if err != nil {
141+
if err := decoder.Decode(&msg); err != nil {
146142
closeAndReportError(err)
147143
return
148144
}
@@ -184,7 +180,10 @@ func (disc *Client) waitMessage(timeout time.Duration) (*discoveryMessage, error
184180
select {
185181
case msg := <-disc.incomingMessagesChan:
186182
if msg == nil {
187-
return nil, disc.incomingMessagesError
183+
disc.statusMutex.Lock()
184+
err := disc.incomingMessagesError
185+
disc.statusMutex.Unlock()
186+
return nil, err
188187
}
189188
return msg, nil
190189
case <-time.After(timeout):
@@ -239,9 +238,6 @@ func (disc *Client) runProcess() error {
239238
}
240239

241240
func (disc *Client) killProcess() {
242-
disc.statusMutex.Lock()
243-
defer disc.statusMutex.Unlock()
244-
245241
disc.logger.Debugf("Killing discovery process")
246242
if process := disc.process; process != nil {
247243
disc.process = nil
@@ -270,7 +266,9 @@ func (disc *Client) Run() (err error) {
270266
if err == nil {
271267
return
272268
}
269+
disc.statusMutex.Lock()
273270
disc.killProcess()
271+
disc.statusMutex.Unlock()
274272
}()
275273

276274
if err = disc.sendCommand("HELLO 1 \"arduino-cli " + disc.userAgent + "\"\n"); err != nil {
@@ -287,8 +285,6 @@ func (disc *Client) Run() (err error) {
287285
} else if msg.ProtocolVersion > 1 {
288286
return fmt.Errorf("protocol version not supported: requested 1, got %d", msg.ProtocolVersion)
289287
}
290-
disc.statusMutex.Lock()
291-
defer disc.statusMutex.Unlock()
292288
return nil
293289
}
294290

@@ -307,8 +303,6 @@ func (disc *Client) Start() error {
307303
} else if strings.ToUpper(msg.Message) != "OK" {
308304
return fmt.Errorf("communication out of sync, expected 'OK', received '%s'", msg.Message)
309305
}
310-
disc.statusMutex.Lock()
311-
defer disc.statusMutex.Unlock()
312306
return nil
313307
}
314308

@@ -348,8 +342,10 @@ func (disc *Client) Quit() {
348342
if _, err := disc.waitMessage(time.Second * 5); err != nil {
349343
disc.logger.Errorf("Quitting discovery: %s", err)
350344
}
345+
disc.statusMutex.Lock()
351346
disc.stopSync()
352347
disc.killProcess()
348+
disc.statusMutex.Unlock()
353349
}
354350

355351
// List executes an enumeration of the ports and returns a list of the available
@@ -377,9 +373,6 @@ func (disc *Client) List() ([]*Port, error) {
377373
// The event channel must be consumed as quickly as possible since it may block the
378374
// discovery if it becomes full. The channel size is configurable.
379375
func (disc *Client) StartSync(size int) (<-chan *Event, error) {
380-
disc.statusMutex.Lock()
381-
defer disc.statusMutex.Unlock()
382-
383376
if err := disc.sendCommand("START_SYNC\n"); err != nil {
384377
return nil, err
385378
}
@@ -395,6 +388,8 @@ func (disc *Client) StartSync(size int) (<-chan *Event, error) {
395388
}
396389

397390
// In case there is already an existing event channel in use we close it before creating a new one.
391+
disc.statusMutex.Lock()
392+
defer disc.statusMutex.Unlock()
398393
disc.stopSync()
399394
c := make(chan *Event, size)
400395
disc.eventChan = c

Diff for: client_test.go

+56
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package discovery
1919

2020
import (
2121
"fmt"
22+
"io"
2223
"net"
2324
"testing"
2425
"time"
@@ -93,3 +94,58 @@ func TestDiscoveryStdioHandling(t *testing.T) {
9394

9495
require.False(t, disc.Alive())
9596
}
97+
98+
func TestClient(t *testing.T) {
99+
// Build dummy-discovery
100+
builder, err := paths.NewProcess(nil, "go", "build")
101+
require.NoError(t, err)
102+
builder.SetDir("dummy-discovery")
103+
require.NoError(t, builder.Run())
104+
105+
t.Run("WithDiscoveryCrashingOnStartup", func(t *testing.T) {
106+
// Run client with discovery crashing on startup
107+
cl := NewClient("1", "dummy-discovery/dummy-discovery", "--invalid")
108+
require.ErrorIs(t, cl.Run(), io.EOF)
109+
})
110+
111+
t.Run("WithDiscoveryCrashingWhileSendingCommands", func(t *testing.T) {
112+
// Run client with crashing discovery after 1 second
113+
cl := NewClient("1", "dummy-discovery/dummy-discovery", "-k")
114+
require.NoError(t, cl.Run())
115+
116+
time.Sleep(time.Second)
117+
118+
ch, err := cl.StartSync(20)
119+
require.Error(t, err)
120+
require.Nil(t, ch)
121+
})
122+
123+
t.Run("WithDiscoveryCrashingWhileStreamingEvents", func(t *testing.T) {
124+
// Run client with crashing discovery after 1 second
125+
cl := NewClient("1", "dummy-discovery/dummy-discovery", "-k")
126+
require.NoError(t, cl.Run())
127+
128+
ch, err := cl.StartSync(20)
129+
require.NoError(t, err)
130+
131+
time.Sleep(time.Second)
132+
133+
loop:
134+
for {
135+
select {
136+
case msg, ok := <-ch:
137+
if !ok {
138+
// Channel closed: Test passed
139+
fmt.Println("Event channel closed")
140+
break loop
141+
}
142+
fmt.Println("Recv: ", msg)
143+
case <-time.After(time.Second):
144+
t.Error("Crashing client did not close event channel")
145+
break loop
146+
}
147+
}
148+
149+
cl.Quit()
150+
})
151+
}

Diff for: dummy-discovery/args/args.go

+9
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package args
2020
import (
2121
"fmt"
2222
"os"
23+
"time"
2324
)
2425

2526
// Tag is the current git tag
@@ -38,6 +39,14 @@ func Parse() {
3839
fmt.Printf("dummy-discovery %s (build timestamp: %s)\n", Tag, Timestamp)
3940
os.Exit(0)
4041
}
42+
if arg == "-k" {
43+
// Emulate crashing discovery
44+
go func() {
45+
time.Sleep(time.Millisecond * 500)
46+
os.Exit(1)
47+
}()
48+
continue
49+
}
4150
fmt.Fprintf(os.Stderr, "invalid argument: %s\n", arg)
4251
os.Exit(1)
4352
}

0 commit comments

Comments
 (0)