Commit d852119
fix: Error handling in StreamingBatchWriter (#1913)
Actual cleanup by @murarustefaan, but I've made some minor updates to it and updated the tests. ~Still need to test it E2E.~ Seems to work: it surfaced the `panic: arrow/array: number of columns/fields mismatch [recovered]` error immediately (I tested with an old version of the S3 plugin and had to manually bump Arrow to v17).
1 parent b0d72e1 commit d852119
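
The core of the fix: previously, worker errors were drained by a background goroutine that only logged them, so a failing handler never failed the write. The `Write` loop now selects on both the message channel and the error channel and returns the first worker error to the caller. A minimal generic sketch of the pattern (simplified names, not the SDK's API):

```go
package main

import "fmt"

// write drains msgs, dispatching each item, and returns the first error
// reported by any worker on errCh: the essence of the new Write loop.
func write(msgs <-chan string, errCh <-chan error) error {
	for {
		select {
		case msg, ok := <-msgs:
			if !ok {
				return nil // input closed; the real writer flushes and closes here
			}
			_ = msg // the real writer dispatches to a per-table worker here
		case err := <-errCh:
			return err // fail fast instead of only logging
		}
	}
}

func main() {
	msgs := make(chan string)
	errCh := make(chan error, 1)
	errCh <- fmt.Errorf("worker failed")
	fmt.Println(write(msgs, errCh)) // prints "worker failed"
}
```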

File tree

2 files changed: +103 -76 lines changed

writers/streamingbatchwriter/streamingbatchwriter.go (+74 -54)
```diff
@@ -182,26 +182,28 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
 	errCh := make(chan error)
 	defer close(errCh)
 
-	go func() {
-		for err := range errCh {
-			w.logger.Err(err).Msg("error from StreamingBatchWriter")
-		}
-	}()
+	for {
+		select {
+		case msg, ok := <-msgs:
+			if !ok {
+				return w.Close(ctx)
+			}
 
-	for msg := range msgs {
-		msgType := writers.MsgID(msg)
-		if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
-			if err := w.Flush(ctx); err != nil {
+			msgType := writers.MsgID(msg)
+			if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
+				if err := w.Flush(ctx); err != nil {
+					return err
+				}
+			}
+			w.lastMsgType = msgType
+			if err := w.startWorker(ctx, errCh, msg); err != nil {
 				return err
 			}
-		}
-		w.lastMsgType = msgType
-		if err := w.startWorker(ctx, errCh, msg); err != nil {
+
+		case err := <-errCh:
 			return err
 		}
 	}
-
-	return w.Close(ctx)
 }
 
 func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
```
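
From the caller's side, the effect is that `Write` itself now reports worker failures. An illustrative fragment (`writer`, `ctx`, and `record` are assumed to be set up elsewhere):

```go
msgs := make(chan message.WriteMessage)
go func() {
	defer close(msgs)
	msgs <- &message.WriteInsert{Record: record} // record: a prepared Arrow record
}()
// Before this commit a handler panic or error was only logged; now it is
// returned from Write, so the caller can abort the sync.
if err := writer.Write(ctx, msgs); err != nil {
	return fmt.Errorf("write failed: %w", err)
}
```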
```diff
@@ -221,13 +223,14 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 	case *message.WriteMigrateTable:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
+
 		if w.migrateWorker != nil {
 			w.migrateWorker.ch <- m
 			return nil
 		}
-		ch := make(chan *message.WriteMigrateTable)
+
 		w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
-			ch:        ch,
+			ch:        make(chan *message.WriteMigrateTable),
 			writeFunc: w.client.MigrateTable,
 
 			flush: make(chan chan bool),
```
```diff
@@ -241,17 +244,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.migrateWorker.ch <- m
+
 		return nil
 	case *message.WriteDeleteStale:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
+
 		if w.deleteStaleWorker != nil {
 			w.deleteStaleWorker.ch <- m
 			return nil
 		}
-		ch := make(chan *message.WriteDeleteStale)
+
 		w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
-			ch:        ch,
+			ch:        make(chan *message.WriteDeleteStale),
 			writeFunc: w.client.DeleteStale,
 
 			flush: make(chan chan bool),
```
```diff
@@ -265,19 +270,29 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.deleteStaleWorker.ch <- m
+
 		return nil
 	case *message.WriteInsert:
 		w.workersLock.RLock()
-		wr, ok := w.insertWorkers[tableName]
+		worker, ok := w.insertWorkers[tableName]
 		w.workersLock.RUnlock()
 		if ok {
-			wr.ch <- m
+			worker.ch <- m
 			return nil
 		}
 
-		ch := make(chan *message.WriteInsert)
-		wr = &streamingWorkerManager[*message.WriteInsert]{
-			ch:        ch,
+		w.workersLock.Lock()
+		activeWorker, ok := w.insertWorkers[tableName]
+		if ok {
+			w.workersLock.Unlock()
+			// some other goroutine could have already added the worker
+			// just send the message to it & discard our allocated worker
+			activeWorker.ch <- m
+			return nil
+		}
+
+		worker = &streamingWorkerManager[*message.WriteInsert]{
+			ch:        make(chan *message.WriteInsert),
 			writeFunc: w.client.WriteTable,
 
 			flush: make(chan chan bool),
```
```diff
@@ -287,33 +302,27 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			batchTimeout: w.batchTimeout,
 			tickerFn:     w.tickerFn,
 		}
-		w.workersLock.Lock()
-		wrOld, ok := w.insertWorkers[tableName]
-		if ok {
-			w.workersLock.Unlock()
-			// some other goroutine could have already added the worker
-			// just send the message to it & discard our allocated worker
-			wrOld.ch <- m
-			return nil
-		}
-		w.insertWorkers[tableName] = wr
+
+		w.insertWorkers[tableName] = worker
 		w.workersLock.Unlock()
 
 		w.workersWaitGroup.Add(1)
-		go wr.run(ctx, &w.workersWaitGroup, tableName)
-		ch <- m
+		go worker.run(ctx, &w.workersWaitGroup, tableName)
+		worker.ch <- m
+
 		return nil
 	case *message.WriteDeleteRecord:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
+
 		if w.deleteRecordWorker != nil {
 			w.deleteRecordWorker.ch <- m
 			return nil
 		}
-		ch := make(chan *message.WriteDeleteRecord)
+
 		// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
 		w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
-			ch:        ch,
+			ch:        make(chan *message.WriteDeleteRecord),
 			writeFunc: w.client.DeleteRecords,
 
 			flush: make(chan chan bool),
```
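
The `WriteInsert` case keeps its double-checked locking, but now allocates the worker only after the re-check under the write lock, so nothing is allocated and discarded on the common path. The shape of the pattern as a standalone sketch (an illustrative registry, not the SDK's types):

```go
package main

import "sync"

// registry maps a key (a table name in the real writer) to a worker channel.
type registry struct {
	mu      sync.RWMutex
	workers map[string]chan int
}

func newRegistry() *registry {
	return &registry{workers: make(map[string]chan int)}
}

func (r *registry) send(key string, v int) {
	// Fast path: most calls find an existing worker under a read lock.
	r.mu.RLock()
	ch, ok := r.workers[key]
	r.mu.RUnlock()
	if ok {
		ch <- v
		return
	}

	r.mu.Lock()
	// Re-check under the write lock: another goroutine may have created the
	// worker between RUnlock and Lock.
	if existing, ok := r.workers[key]; ok {
		r.mu.Unlock()
		existing <- v
		return
	}
	ch = make(chan int)
	r.workers[key] = ch
	r.mu.Unlock()

	go func() { // hypothetical worker draining the channel
		for range ch {
		}
	}()
	ch <- v
}

func main() {
	r := newRegistry()
	r.send("table_a", 1)
	r.send("table_a", 2) // reuses the existing worker via the fast path
}
```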
```diff
@@ -327,6 +336,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.deleteRecordWorker.ch <- m
+
 		return nil
 	default:
 		return fmt.Errorf("unhandled message type: %T", msg)
```
```diff
@@ -348,35 +358,40 @@ type streamingWorkerManager[T message.WriteMessage] struct {
 func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
 	defer wg.Done()
 	var (
-		clientCh    chan T
-		clientErrCh chan error
-		open        bool
+		inputCh  chan T
+		outputCh chan error
+		open     bool
 	)
 
 	ensureOpened := func() {
 		if open {
 			return
 		}
 
-		clientCh = make(chan T)
-		clientErrCh = make(chan error, 1)
+		inputCh = make(chan T)
+		outputCh = make(chan error)
 		go func() {
-			defer close(clientErrCh)
+			defer close(outputCh)
 			defer func() {
-				if err := recover(); err != nil {
-					clientErrCh <- fmt.Errorf("panic: %v", err)
+				if msg := recover(); msg != nil {
+					switch v := msg.(type) {
+					case error:
+						outputCh <- fmt.Errorf("panic: %w [recovered]", v)
+					default:
+						outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
+					}
 				}
 			}()
-			clientErrCh <- s.writeFunc(ctx, clientCh)
+			result := s.writeFunc(ctx, inputCh)
+			outputCh <- result
 		}()
+
 		open = true
 	}
+
 	closeFlush := func() {
 		if open {
-			close(clientCh)
-			if err := <-clientErrCh; err != nil {
-				s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
-			}
+			close(inputCh)
 			s.limit.Reset()
 		}
 		open = false
```
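
The recovery path now also distinguishes panic values that are themselves errors, wrapping them with `%w` so callers can still match them with `errors.Is`/`errors.As`. A standalone sketch of the idea (not the SDK's exact code):

```go
package main

import (
	"errors"
	"fmt"
)

var errBoom = errors.New("boom")

// runHandler converts a panic from handler into a returned error,
// preserving error values for errors.Is/errors.As via %w.
func runHandler(handler func() error) (err error) {
	defer func() {
		if msg := recover(); msg != nil {
			switch v := msg.(type) {
			case error:
				err = fmt.Errorf("panic: %w [recovered]", v) // keeps the error chain
			default:
				err = fmt.Errorf("panic: %v [recovered]", msg)
			}
		}
	}()
	return handler()
}

func main() {
	err := runHandler(func() error { panic(errBoom) })
	fmt.Println(errors.Is(err, errBoom)) // true: %w preserved errBoom
}
```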
```diff
@@ -400,7 +415,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			if add != nil {
 				ensureOpened()
 				s.limit.AddSlice(add)
-				clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
+				inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
 			}
 			if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
 				// flush current batch
```
```diff
@@ -410,7 +425,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			for _, sliceToFlush := range toFlush {
 				ensureOpened()
 				s.limit.AddRows(sliceToFlush.NumRows())
-				clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
+				inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
 				closeFlush()
 				ticker.Reset(s.batchTimeout)
 			}
```
```diff
@@ -419,11 +434,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			if rest != nil {
 				ensureOpened()
 				s.limit.AddSlice(rest)
-				clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
+				inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
 			}
 		} else {
 			ensureOpened()
-			clientCh <- r
+			inputCh <- r
 			s.limit.AddRows(1)
 			if s.limit.ReachedLimit() {
 				closeFlush()
```
```diff
@@ -441,6 +456,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 				ticker.Reset(s.batchTimeout)
 			}
 			done <- true
+		case err := <-outputCh:
+			if err != nil {
+				s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
+				return
+			}
 		case <-ctxDone:
 			// this means the request was cancelled
 			return // after this NO other call will succeed
```

writers/streamingbatchwriter/streamingbatchwriter_test.go (+29 -22)
```diff
@@ -201,20 +201,30 @@ func TestStreamingBatchSizeRows(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	time.Sleep(50 * time.Millisecond)
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	ch <- &message.WriteInsert{ // third message, because we flush before exceeding the limit and then save the third one
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
+
+	ch <- &message.WriteInsert{
 		Record: record,
 	}
 
 	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
+
+	ch <- &message.WriteInsert{
+		Record: record,
+	}
+
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 4)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
 
 	close(ch)
 	if err := <-errCh; err != nil {
```
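
The tests replace fixed sleeps with a `waitForLength` helper (plus a new `InflightLen` counter on the test client). Neither implementation is part of this diff; a plausible polling version of the helper, with the signature inferred from the call sites, could look like:

```go
// Hypothetical sketch: poll a length getter until it reports want, failing
// the test after a deadline. Names, timeout, and poll interval are assumptions.
func waitForLength(t *testing.T, get func(messageType) int, mt messageType, want int) {
	t.Helper()
	deadline := time.Now().Add(2 * time.Second)
	for {
		if got := get(mt); got == want {
			return
		}
		if time.Now().After(deadline) {
			t.Fatalf("expected length %d, got %d", want, get(mt))
		}
		time.Sleep(5 * time.Millisecond)
	}
}
```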
```diff
@@ -225,7 +235,7 @@ func TestStreamingBatchSizeRows(t *testing.T) {
 		t.Fatalf("expected 0 open tables, got %d", l)
 	}
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 3 {
+	if l := testClient.MessageLen(messageTypeInsert); l != 4 {
 		t.Fatalf("expected 3 insert messages, got %d", l)
 	}
 }
```
```diff
@@ -253,18 +263,12 @@ func TestStreamingBatchTimeout(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	time.Sleep(50 * time.Millisecond)
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
 
-	// we need to wait for the batch to be flushed
-	time.Sleep(time.Millisecond * 50)
+	time.Sleep(time.Millisecond * 50) // we need to wait for the batch to be flushed
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
 
 	// flush
 	tickFn()
```
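
`TestStreamingBatchTimeout` drives the flush deterministically through `tickFn` instead of sleeping until the batch timeout elapses. A minimal sketch of how such an injectable ticker can work (hypothetical names; the SDK's actual `tickerFn` plumbing may differ):

```go
package main

import (
	"fmt"
	"time"
)

// worker flushes the current batch on every tick it receives.
func worker(tick <-chan time.Time, flush func()) {
	for range tick {
		flush()
	}
}

func main() {
	tickCh := make(chan time.Time)
	tickFn := func() { tickCh <- time.Now() } // what the test calls to force a flush
	go worker(tickCh, func() { fmt.Println("flushed") })

	tickFn() // deterministic flush, no waiting on wall-clock time
	time.Sleep(10 * time.Millisecond) // give the worker a moment to print
}
```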
```diff
@@ -301,32 +305,35 @@ func TestStreamingBatchNoTimeout(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-	time.Sleep(50 * time.Millisecond)
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	time.Sleep(2 * time.Second)
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
-		t.Fatalf("expected 0 insert messages, got %d", l)
-	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
+
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
 
 	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
+	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	close(ch)
 	if err := <-errCh; err != nil {
 		t.Fatal(err)
 	}
 
+	time.Sleep(50 * time.Millisecond)
+
 	if l := testClient.OpenLen(messageTypeInsert); l != 0 {
 		t.Fatalf("expected 0 open tables, got %d", l)
 	}
```
