diff --git a/writers/streamingbatchwriter/streamingbatchwriter.go b/writers/streamingbatchwriter/streamingbatchwriter.go
index d5825df31d..23b6e8c139 100644
--- a/writers/streamingbatchwriter/streamingbatchwriter.go
+++ b/writers/streamingbatchwriter/streamingbatchwriter.go
@@ -182,28 +182,26 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
 	errCh := make(chan error)
 	defer close(errCh)
 
-	for {
-		select {
-		case msg, ok := <-msgs:
-			if !ok {
-				return w.Close(ctx)
-			}
+	go func() {
+		for err := range errCh {
+			w.logger.Err(err).Msg("error from StreamingBatchWriter")
+		}
+	}()
 
-			msgType := writers.MsgID(msg)
-			if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
-				if err := w.Flush(ctx); err != nil {
-					return err
-				}
-			}
-			w.lastMsgType = msgType
-			if err := w.startWorker(ctx, errCh, msg); err != nil {
+	for msg := range msgs {
+		msgType := writers.MsgID(msg)
+		if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
+			if err := w.Flush(ctx); err != nil {
 				return err
 			}
-
-		case err := <-errCh:
+		}
+		w.lastMsgType = msgType
+		if err := w.startWorker(ctx, errCh, msg); err != nil {
 			return err
 		}
 	}
+
+	return w.Close(ctx)
 }
 
 func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
@@ -223,14 +221,13 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 	case *message.WriteMigrateTable:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
-
 		if w.migrateWorker != nil {
 			w.migrateWorker.ch <- m
 			return nil
 		}
-
+		ch := make(chan *message.WriteMigrateTable)
 		w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
-			ch:        make(chan *message.WriteMigrateTable),
+			ch:        ch,
 			writeFunc: w.client.MigrateTable,
 
 			flush: make(chan chan bool),
@@ -244,19 +241,17 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.migrateWorker.ch <- m
-
 		return nil
 	case *message.WriteDeleteStale:
 		w.workersLock.Lock()
 		defer w.workersLock.Unlock()
-
 		if w.deleteStaleWorker != nil {
 			w.deleteStaleWorker.ch <- m
 			return nil
 		}
-
+		ch := make(chan *message.WriteDeleteStale)
 		w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
-			ch:        make(chan *message.WriteDeleteStale),
+			ch:        ch,
 			writeFunc: w.client.DeleteStale,
 
 			flush: make(chan chan bool),
@@ -270,29 +265,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.deleteStaleWorker.ch <- m
-
 		return nil
 	case *message.WriteInsert:
 		w.workersLock.RLock()
-		worker, ok := w.insertWorkers[tableName]
+		wr, ok := w.insertWorkers[tableName]
 		w.workersLock.RUnlock()
 		if ok {
-			worker.ch <- m
+			wr.ch <- m
 			return nil
 		}
 
-		w.workersLock.Lock()
-		activeWorker, ok := w.insertWorkers[tableName]
-		if ok {
-			w.workersLock.Unlock()
-			// some other goroutine could have already added the worker
-			// just send the message to it & discard our allocated worker
-			activeWorker.ch <- m
-			return nil
-		}
-
-		worker = &streamingWorkerManager[*message.WriteInsert]{
-			ch:        make(chan *message.WriteInsert),
+		ch := make(chan *message.WriteInsert)
+		wr = &streamingWorkerManager[*message.WriteInsert]{
+			ch:        ch,
 			writeFunc: w.client.WriteTable,
 
 			flush: make(chan chan bool),
@@ -302,27 +287,33 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 			batchTimeout: w.batchTimeout,
 			tickerFn:     w.tickerFn,
 		}
-
-		w.insertWorkers[tableName] = worker
+		w.workersLock.Lock()
+		wrOld, ok := w.insertWorkers[tableName]
+		if ok {
+			w.workersLock.Unlock()
+			// some other goroutine could have already added the worker
+			// just send the message to it & discard our allocated worker
+			wrOld.ch <- m
+			return nil
+		}
+		w.insertWorkers[tableName] = wr
 		w.workersLock.Unlock()
 
 		w.workersWaitGroup.Add(1)
-		go worker.run(ctx, &w.workersWaitGroup, tableName)
-		worker.ch <- m
-
+		go wr.run(ctx, &w.workersWaitGroup, tableName)
+		ch <- m
 		return nil
 	case *message.WriteDeleteRecord:
 		w.workersLock.Lock()
		defer w.workersLock.Unlock()
-
 		if w.deleteRecordWorker != nil {
 			w.deleteRecordWorker.ch <- m
 			return nil
 		}
-
+		ch := make(chan *message.WriteDeleteRecord)
 		// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
 		w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
-			ch:        make(chan *message.WriteDeleteRecord),
+			ch:        ch,
 			writeFunc: w.client.DeleteRecords,
 
 			flush: make(chan chan bool),
@@ -336,7 +327,6 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
 		w.workersWaitGroup.Add(1)
 		go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
 		w.deleteRecordWorker.ch <- m
-
 		return nil
 	default:
 		return fmt.Errorf("unhandled message type: %T", msg)
@@ -358,9 +348,9 @@ type streamingWorkerManager[T message.WriteMessage] struct {
 func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
 	defer wg.Done()
 	var (
-		inputCh  chan T
-		outputCh chan error
-		open     bool
+		clientCh    chan T
+		clientErrCh chan error
+		open        bool
 	)
 
 	ensureOpened := func() {
@@ -368,30 +358,25 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			return
 		}
 
-		inputCh = make(chan T)
-		outputCh = make(chan error)
+		clientCh = make(chan T)
+		clientErrCh = make(chan error, 1)
 		go func() {
-			defer close(outputCh)
+			defer close(clientErrCh)
 			defer func() {
-				if msg := recover(); msg != nil {
-					switch v := msg.(type) {
-					case error:
-						outputCh <- fmt.Errorf("panic: %w [recovered]", v)
-					default:
-						outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
-					}
+				if err := recover(); err != nil {
+					clientErrCh <- fmt.Errorf("panic: %v", err)
 				}
 			}()
-			result := s.writeFunc(ctx, inputCh)
-			outputCh <- result
+			clientErrCh <- s.writeFunc(ctx, clientCh)
 		}()
-
 		open = true
 	}
-
 	closeFlush := func() {
 		if open {
-			close(inputCh)
+			close(clientCh)
+			if err := <-clientErrCh; err != nil {
+				s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
+			}
 			s.limit.Reset()
 		}
 		open = false
@@ -415,7 +400,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			if add != nil {
 				ensureOpened()
 				s.limit.AddSlice(add)
-				inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
+				clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
 			}
 			if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
 				// flush current batch
@@ -425,7 +410,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			for _, sliceToFlush := range toFlush {
 				ensureOpened()
 				s.limit.AddRows(sliceToFlush.NumRows())
-				inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
+				clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
 				closeFlush()
 				ticker.Reset(s.batchTimeout)
 			}
@@ -434,11 +419,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 			if rest != nil {
 				ensureOpened()
 				s.limit.AddSlice(rest)
-				inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
+				clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
 			}
 		} else {
 			ensureOpened()
-			inputCh <- r
+			clientCh <- r
 			s.limit.AddRows(1)
 			if s.limit.ReachedLimit() {
 				closeFlush()
@@ -456,11 +441,6 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
 				ticker.Reset(s.batchTimeout)
 			}
 			done <- true
-		case err := <-outputCh:
-			if err != nil {
-				s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
-				return
-			}
 		case <-ctxDone:
 			// this means the request was cancelled
 			return // after this NO other call will succeed
diff --git a/writers/streamingbatchwriter/streamingbatchwriter_test.go b/writers/streamingbatchwriter/streamingbatchwriter_test.go
index 7e6703c92d..08cabbfd1a 100644
--- a/writers/streamingbatchwriter/streamingbatchwriter_test.go
+++ b/writers/streamingbatchwriter/streamingbatchwriter_test.go
@@ -201,30 +201,20 @@ func TestStreamingBatchSizeRows(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
+	time.Sleep(50 * time.Millisecond)
 
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
-
-	ch <- &message.WriteInsert{
-		Record: record,
+	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
+		t.Fatalf("expected 0 insert messages, got %d", l)
 	}
 
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
-
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
-
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
-
-	ch <- &message.WriteInsert{
+	ch <- &message.WriteInsert{ // third message, because we flush before exceeding the limit and then save the third one
 		Record: record,
 	}
 
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 4)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
+	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
 
 	close(ch)
 	if err := <-errCh; err != nil {
@@ -235,7 +225,7 @@ func TestStreamingBatchSizeRows(t *testing.T) {
 		t.Fatalf("expected 0 open tables, got %d", l)
 	}
 
-	if l := testClient.MessageLen(messageTypeInsert); l != 4 {
+	if l := testClient.MessageLen(messageTypeInsert); l != 3 {
 		t.Fatalf("expected 3 insert messages, got %d", l)
 	}
 }
@@ -263,12 +253,18 @@ func TestStreamingBatchTimeout(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
+	time.Sleep(50 * time.Millisecond)
 
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
+		t.Fatalf("expected 0 insert messages, got %d", l)
+	}
 
-	time.Sleep(time.Millisecond * 50) // we need to wait for the batch to be flushed
+	// we need to wait for the batch to be flushed
+	time.Sleep(time.Millisecond * 50)
 
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
+	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
+		t.Fatalf("expected 0 insert messages, got %d", l)
+	}
 
 	// flush
 	tickFn()
@@ -305,35 +301,32 @@ func TestStreamingBatchNoTimeout(t *testing.T) {
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
+	time.Sleep(50 * time.Millisecond)
 
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
+	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
+		t.Fatalf("expected 0 insert messages, got %d", l)
+	}
 
 	time.Sleep(2 * time.Second)
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
+	if l := testClient.MessageLen(messageTypeInsert); l != 0 {
+		t.Fatalf("expected 0 insert messages, got %d", l)
+	}
 
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
 
-	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)
-
 	ch <- &message.WriteInsert{
 		Record: record,
 	}
 
 	waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
-	waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)
 
 	close(ch)
 	if err := <-errCh; err != nil {
 		t.Fatal(err)
 	}
 
-	time.Sleep(50 * time.Millisecond)
-
 	if l := testClient.OpenLen(messageTypeInsert); l != 0 {
 		t.Fatalf("expected 0 open tables, got %d", l)
 	}
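
Note on the pattern above (not part of the patch): each worker now owns a per-batch channel plus a one-slot error channel; closing the batch channel is the flush signal, the client's single return value is read back inside closeFlush, and Write drains the shared error channel in its own goroutine so workers never block while reporting failures. The sketch below is a minimal, self-contained illustration of that lifecycle under those assumptions; demoWriteFunc and the table name are made-up stand-ins, not SDK APIs.

package main

import (
	"context"
	"errors"
	"fmt"
)

// demoWriteFunc stands in for a client write function such as WriteTable:
// it consumes messages until its input channel is closed and returns one error.
func demoWriteFunc(ctx context.Context, in <-chan string) error {
	for msg := range in {
		if msg == "bad" {
			for range in { // keep draining so the sender never blocks
			}
			return errors.New("write failed")
		}
	}
	return nil
}

func main() {
	errCh := make(chan error)
	done := make(chan struct{})

	// Write side: drain worker errors so the flush path never blocks on errCh.
	go func() {
		defer close(done)
		for err := range errCh {
			fmt.Println("error from worker:", err)
		}
	}()

	// Worker side: one channel per batch, plus a one-slot channel for the
	// client's single result so the client goroutine can exit immediately.
	clientCh := make(chan string)
	clientErrCh := make(chan error, 1)
	go func() {
		clientErrCh <- demoWriteFunc(context.Background(), clientCh)
	}()

	clientCh <- "row-1"
	clientCh <- "bad"

	// closeFlush equivalent: close the batch channel, then surface the result.
	close(clientCh)
	if err := <-clientErrCh; err != nil {
		errCh <- fmt.Errorf("handler failed on %s: %w", "demo_table", err)
	}

	close(errCh)
	<-done
}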