@@ -182,26 +182,28 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
182
182
errCh := make (chan error )
183
183
defer close (errCh )
184
184
185
- go func () {
186
- for err := range errCh {
187
- w .logger .Err (err ).Msg ("error from StreamingBatchWriter" )
188
- }
189
- }()
185
+ for {
186
+ select {
187
+ case msg , ok := <- msgs :
188
+ if ! ok {
189
+ return w .Close (ctx )
190
+ }
190
191
191
- for msg := range msgs {
192
- msgType := writers .MsgID (msg )
193
- if w .lastMsgType != writers .MsgTypeUnset && w .lastMsgType != msgType {
194
- if err := w .Flush (ctx ); err != nil {
192
+ msgType := writers .MsgID (msg )
193
+ if w .lastMsgType != writers .MsgTypeUnset && w .lastMsgType != msgType {
194
+ if err := w .Flush (ctx ); err != nil {
195
+ return err
196
+ }
197
+ }
198
+ w .lastMsgType = msgType
199
+ if err := w .startWorker (ctx , errCh , msg ); err != nil {
195
200
return err
196
201
}
197
- }
198
- w .lastMsgType = msgType
199
- if err := w .startWorker (ctx , errCh , msg ); err != nil {
202
+
203
+ case err := <- errCh :
200
204
return err
201
205
}
202
206
}
203
-
204
- return w .Close (ctx )
205
207
}
206
208
207
209
func (w * StreamingBatchWriter ) startWorker (ctx context.Context , errCh chan <- error , msg message.WriteMessage ) error {
@@ -221,13 +223,14 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
221
223
case * message.WriteMigrateTable :
222
224
w .workersLock .Lock ()
223
225
defer w .workersLock .Unlock ()
226
+
224
227
if w .migrateWorker != nil {
225
228
w .migrateWorker .ch <- m
226
229
return nil
227
230
}
228
- ch := make ( chan * message. WriteMigrateTable )
231
+
229
232
w .migrateWorker = & streamingWorkerManager [* message.WriteMigrateTable ]{
230
- ch : ch ,
233
+ ch : make ( chan * message. WriteMigrateTable ) ,
231
234
writeFunc : w .client .MigrateTable ,
232
235
233
236
flush : make (chan chan bool ),
@@ -241,17 +244,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
241
244
w .workersWaitGroup .Add (1 )
242
245
go w .migrateWorker .run (ctx , & w .workersWaitGroup , tableName )
243
246
w .migrateWorker .ch <- m
247
+
244
248
return nil
245
249
case * message.WriteDeleteStale :
246
250
w .workersLock .Lock ()
247
251
defer w .workersLock .Unlock ()
252
+
248
253
if w .deleteStaleWorker != nil {
249
254
w .deleteStaleWorker .ch <- m
250
255
return nil
251
256
}
252
- ch := make ( chan * message. WriteDeleteStale )
257
+
253
258
w .deleteStaleWorker = & streamingWorkerManager [* message.WriteDeleteStale ]{
254
- ch : ch ,
259
+ ch : make ( chan * message. WriteDeleteStale ) ,
255
260
writeFunc : w .client .DeleteStale ,
256
261
257
262
flush : make (chan chan bool ),
@@ -265,19 +270,29 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
265
270
w .workersWaitGroup .Add (1 )
266
271
go w .deleteStaleWorker .run (ctx , & w .workersWaitGroup , tableName )
267
272
w .deleteStaleWorker .ch <- m
273
+
268
274
return nil
269
275
case * message.WriteInsert :
270
276
w .workersLock .RLock ()
271
- wr , ok := w .insertWorkers [tableName ]
277
+ worker , ok := w .insertWorkers [tableName ]
272
278
w .workersLock .RUnlock ()
273
279
if ok {
274
- wr .ch <- m
280
+ worker .ch <- m
275
281
return nil
276
282
}
277
283
278
- ch := make (chan * message.WriteInsert )
279
- wr = & streamingWorkerManager [* message.WriteInsert ]{
280
- ch : ch ,
284
+ w .workersLock .Lock ()
285
+ activeWorker , ok := w .insertWorkers [tableName ]
286
+ if ok {
287
+ w .workersLock .Unlock ()
288
+ // some other goroutine could have already added the worker
289
+ // just send the message to it & discard our allocated worker
290
+ activeWorker .ch <- m
291
+ return nil
292
+ }
293
+
294
+ worker = & streamingWorkerManager [* message.WriteInsert ]{
295
+ ch : make (chan * message.WriteInsert ),
281
296
writeFunc : w .client .WriteTable ,
282
297
283
298
flush : make (chan chan bool ),
@@ -287,33 +302,27 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
287
302
batchTimeout : w .batchTimeout ,
288
303
tickerFn : w .tickerFn ,
289
304
}
290
- w .workersLock .Lock ()
291
- wrOld , ok := w .insertWorkers [tableName ]
292
- if ok {
293
- w .workersLock .Unlock ()
294
- // some other goroutine could have already added the worker
295
- // just send the message to it & discard our allocated worker
296
- wrOld .ch <- m
297
- return nil
298
- }
299
- w .insertWorkers [tableName ] = wr
305
+
306
+ w .insertWorkers [tableName ] = worker
300
307
w .workersLock .Unlock ()
301
308
302
309
w .workersWaitGroup .Add (1 )
303
- go wr .run (ctx , & w .workersWaitGroup , tableName )
304
- ch <- m
310
+ go worker .run (ctx , & w .workersWaitGroup , tableName )
311
+ worker .ch <- m
312
+
305
313
return nil
306
314
case * message.WriteDeleteRecord :
307
315
w .workersLock .Lock ()
308
316
defer w .workersLock .Unlock ()
317
+
309
318
if w .deleteRecordWorker != nil {
310
319
w .deleteRecordWorker .ch <- m
311
320
return nil
312
321
}
313
- ch := make ( chan * message. WriteDeleteRecord )
322
+
314
323
// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
315
324
w .deleteRecordWorker = & streamingWorkerManager [* message.WriteDeleteRecord ]{
316
- ch : ch ,
325
+ ch : make ( chan * message. WriteDeleteRecord ) ,
317
326
writeFunc : w .client .DeleteRecords ,
318
327
319
328
flush : make (chan chan bool ),
@@ -327,6 +336,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
327
336
w .workersWaitGroup .Add (1 )
328
337
go w .deleteRecordWorker .run (ctx , & w .workersWaitGroup , tableName )
329
338
w .deleteRecordWorker .ch <- m
339
+
330
340
return nil
331
341
default :
332
342
return fmt .Errorf ("unhandled message type: %T" , msg )
@@ -348,35 +358,40 @@ type streamingWorkerManager[T message.WriteMessage] struct {
348
358
func (s * streamingWorkerManager [T ]) run (ctx context.Context , wg * sync.WaitGroup , tableName string ) {
349
359
defer wg .Done ()
350
360
var (
351
- clientCh chan T
352
- clientErrCh chan error
353
- open bool
361
+ inputCh chan T
362
+ outputCh chan error
363
+ open bool
354
364
)
355
365
356
366
ensureOpened := func () {
357
367
if open {
358
368
return
359
369
}
360
370
361
- clientCh = make (chan T )
362
- clientErrCh = make (chan error , 1 )
371
+ inputCh = make (chan T )
372
+ outputCh = make (chan error )
363
373
go func () {
364
- defer close (clientErrCh )
374
+ defer close (outputCh )
365
375
defer func () {
366
- if err := recover (); err != nil {
367
- clientErrCh <- fmt .Errorf ("panic: %v" , err )
376
+ if msg := recover (); msg != nil {
377
+ switch v := msg .(type ) {
378
+ case error :
379
+ outputCh <- fmt .Errorf ("panic: %w [recovered]" , v )
380
+ default :
381
+ outputCh <- fmt .Errorf ("panic: %v [recovered]" , msg )
382
+ }
368
383
}
369
384
}()
370
- clientErrCh <- s .writeFunc (ctx , clientCh )
385
+ result := s .writeFunc (ctx , inputCh )
386
+ outputCh <- result
371
387
}()
388
+
372
389
open = true
373
390
}
391
+
374
392
closeFlush := func () {
375
393
if open {
376
- close (clientCh )
377
- if err := <- clientErrCh ; err != nil {
378
- s .errCh <- fmt .Errorf ("handler failed on %s: %w" , tableName , err )
379
- }
394
+ close (inputCh )
380
395
s .limit .Reset ()
381
396
}
382
397
open = false
@@ -400,7 +415,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
400
415
if add != nil {
401
416
ensureOpened ()
402
417
s .limit .AddSlice (add )
403
- clientCh <- any (& message.WriteInsert {Record : add .Record }).(T )
418
+ inputCh <- any (& message.WriteInsert {Record : add .Record }).(T )
404
419
}
405
420
if len (toFlush ) > 0 || rest != nil || s .limit .ReachedLimit () {
406
421
// flush current batch
@@ -410,7 +425,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
410
425
for _ , sliceToFlush := range toFlush {
411
426
ensureOpened ()
412
427
s .limit .AddRows (sliceToFlush .NumRows ())
413
- clientCh <- any (& message.WriteInsert {Record : sliceToFlush }).(T )
428
+ inputCh <- any (& message.WriteInsert {Record : sliceToFlush }).(T )
414
429
closeFlush ()
415
430
ticker .Reset (s .batchTimeout )
416
431
}
@@ -419,11 +434,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
419
434
if rest != nil {
420
435
ensureOpened ()
421
436
s .limit .AddSlice (rest )
422
- clientCh <- any (& message.WriteInsert {Record : rest .Record }).(T )
437
+ inputCh <- any (& message.WriteInsert {Record : rest .Record }).(T )
423
438
}
424
439
} else {
425
440
ensureOpened ()
426
- clientCh <- r
441
+ inputCh <- r
427
442
s .limit .AddRows (1 )
428
443
if s .limit .ReachedLimit () {
429
444
closeFlush ()
@@ -441,6 +456,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
441
456
ticker .Reset (s .batchTimeout )
442
457
}
443
458
done <- true
459
+ case err := <- outputCh :
460
+ if err != nil {
461
+ s .errCh <- fmt .Errorf ("handler failed on %s: %w" , tableName , err )
462
+ return
463
+ }
444
464
case <- ctxDone :
445
465
// this means the request was cancelled
446
466
return // after this NO other call will succeed
0 commit comments