Skip to content

Commit

Permalink
Revert "fix: Revert "fix: Error handling in StreamingBatchWriter" (#1918)"
Browse files Browse the repository at this point in the history

This reverts commit 38b4bfd.
  • Loading branch information
disq committed Oct 3, 2024
1 parent 38b4bfd commit 9caa659
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 76 deletions.
128 changes: 74 additions & 54 deletions writers/streamingbatchwriter/streamingbatchwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,26 +182,28 @@ func (w *StreamingBatchWriter) Write(ctx context.Context, msgs <-chan message.Wr
errCh := make(chan error)
defer close(errCh)

go func() {
for err := range errCh {
w.logger.Err(err).Msg("error from StreamingBatchWriter")
}
}()
for {
select {
case msg, ok := <-msgs:
if !ok {
return w.Close(ctx)
}

for msg := range msgs {
msgType := writers.MsgID(msg)
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
if err := w.Flush(ctx); err != nil {
msgType := writers.MsgID(msg)
if w.lastMsgType != writers.MsgTypeUnset && w.lastMsgType != msgType {
if err := w.Flush(ctx); err != nil {
return err
}
}
w.lastMsgType = msgType
if err := w.startWorker(ctx, errCh, msg); err != nil {
return err
}
}
w.lastMsgType = msgType
if err := w.startWorker(ctx, errCh, msg); err != nil {

case err := <-errCh:
return err
}
}

return w.Close(ctx)
}

func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- error, msg message.WriteMessage) error {
Expand All @@ -221,13 +223,14 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
case *message.WriteMigrateTable:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.migrateWorker != nil {
w.migrateWorker.ch <- m
return nil
}
ch := make(chan *message.WriteMigrateTable)

w.migrateWorker = &streamingWorkerManager[*message.WriteMigrateTable]{
ch: ch,
ch: make(chan *message.WriteMigrateTable),
writeFunc: w.client.MigrateTable,

flush: make(chan chan bool),
Expand All @@ -241,17 +244,19 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.migrateWorker.run(ctx, &w.workersWaitGroup, tableName)
w.migrateWorker.ch <- m

return nil
case *message.WriteDeleteStale:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.deleteStaleWorker != nil {
w.deleteStaleWorker.ch <- m
return nil
}
ch := make(chan *message.WriteDeleteStale)

w.deleteStaleWorker = &streamingWorkerManager[*message.WriteDeleteStale]{
ch: ch,
ch: make(chan *message.WriteDeleteStale),
writeFunc: w.client.DeleteStale,

flush: make(chan chan bool),
Expand All @@ -265,19 +270,29 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.deleteStaleWorker.run(ctx, &w.workersWaitGroup, tableName)
w.deleteStaleWorker.ch <- m

return nil
case *message.WriteInsert:
w.workersLock.RLock()
wr, ok := w.insertWorkers[tableName]
worker, ok := w.insertWorkers[tableName]
w.workersLock.RUnlock()
if ok {
wr.ch <- m
worker.ch <- m
return nil
}

ch := make(chan *message.WriteInsert)
wr = &streamingWorkerManager[*message.WriteInsert]{
ch: ch,
w.workersLock.Lock()
activeWorker, ok := w.insertWorkers[tableName]
if ok {
w.workersLock.Unlock()
// some other goroutine could have already added the worker
// just send the message to it & discard our allocated worker
activeWorker.ch <- m
return nil
}

worker = &streamingWorkerManager[*message.WriteInsert]{
ch: make(chan *message.WriteInsert),
writeFunc: w.client.WriteTable,

flush: make(chan chan bool),
Expand All @@ -287,33 +302,27 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
batchTimeout: w.batchTimeout,
tickerFn: w.tickerFn,
}
w.workersLock.Lock()
wrOld, ok := w.insertWorkers[tableName]
if ok {
w.workersLock.Unlock()
// some other goroutine could have already added the worker
// just send the message to it & discard our allocated worker
wrOld.ch <- m
return nil
}
w.insertWorkers[tableName] = wr

w.insertWorkers[tableName] = worker
w.workersLock.Unlock()

w.workersWaitGroup.Add(1)
go wr.run(ctx, &w.workersWaitGroup, tableName)
ch <- m
go worker.run(ctx, &w.workersWaitGroup, tableName)
worker.ch <- m

return nil
case *message.WriteDeleteRecord:
w.workersLock.Lock()
defer w.workersLock.Unlock()

if w.deleteRecordWorker != nil {
w.deleteRecordWorker.ch <- m
return nil
}
ch := make(chan *message.WriteDeleteRecord)

// TODO: flush all workers for nested tables as well (See https://github.com/cloudquery/plugin-sdk/issues/1296)
w.deleteRecordWorker = &streamingWorkerManager[*message.WriteDeleteRecord]{
ch: ch,
ch: make(chan *message.WriteDeleteRecord),
writeFunc: w.client.DeleteRecords,

flush: make(chan chan bool),
Expand All @@ -327,6 +336,7 @@ func (w *StreamingBatchWriter) startWorker(ctx context.Context, errCh chan<- err
w.workersWaitGroup.Add(1)
go w.deleteRecordWorker.run(ctx, &w.workersWaitGroup, tableName)
w.deleteRecordWorker.ch <- m

return nil
default:
return fmt.Errorf("unhandled message type: %T", msg)
Expand All @@ -348,35 +358,40 @@ type streamingWorkerManager[T message.WriteMessage] struct {
func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup, tableName string) {
defer wg.Done()
var (
clientCh chan T
clientErrCh chan error
open bool
inputCh chan T
outputCh chan error
open bool
)

ensureOpened := func() {
if open {
return
}

clientCh = make(chan T)
clientErrCh = make(chan error, 1)
inputCh = make(chan T)
outputCh = make(chan error)
go func() {
defer close(clientErrCh)
defer close(outputCh)
defer func() {
if err := recover(); err != nil {
clientErrCh <- fmt.Errorf("panic: %v", err)
if msg := recover(); msg != nil {
switch v := msg.(type) {
case error:
outputCh <- fmt.Errorf("panic: %w [recovered]", v)
default:
outputCh <- fmt.Errorf("panic: %v [recovered]", msg)
}
}
}()
clientErrCh <- s.writeFunc(ctx, clientCh)
result := s.writeFunc(ctx, inputCh)
outputCh <- result
}()

open = true
}

closeFlush := func() {
if open {
close(clientCh)
if err := <-clientErrCh; err != nil {
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
}
close(inputCh)
s.limit.Reset()
}
open = false
Expand All @@ -400,7 +415,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
if add != nil {
ensureOpened()
s.limit.AddSlice(add)
clientCh <- any(&message.WriteInsert{Record: add.Record}).(T)
inputCh <- any(&message.WriteInsert{Record: add.Record}).(T)
}
if len(toFlush) > 0 || rest != nil || s.limit.ReachedLimit() {
// flush current batch
Expand All @@ -410,7 +425,7 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
for _, sliceToFlush := range toFlush {
ensureOpened()
s.limit.AddRows(sliceToFlush.NumRows())
clientCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
inputCh <- any(&message.WriteInsert{Record: sliceToFlush}).(T)
closeFlush()
ticker.Reset(s.batchTimeout)
}
Expand All @@ -419,11 +434,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
if rest != nil {
ensureOpened()
s.limit.AddSlice(rest)
clientCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
inputCh <- any(&message.WriteInsert{Record: rest.Record}).(T)
}
} else {
ensureOpened()
clientCh <- r
inputCh <- r
s.limit.AddRows(1)
if s.limit.ReachedLimit() {
closeFlush()
Expand All @@ -441,6 +456,11 @@ func (s *streamingWorkerManager[T]) run(ctx context.Context, wg *sync.WaitGroup,
ticker.Reset(s.batchTimeout)
}
done <- true
case err := <-outputCh:
if err != nil {
s.errCh <- fmt.Errorf("handler failed on %s: %w", tableName, err)
return
}
case <-ctxDone:
// this means the request was cancelled
return // after this NO other call will succeed
Expand Down
51 changes: 29 additions & 22 deletions writers/streamingbatchwriter/streamingbatchwriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,20 +201,30 @@ func TestStreamingBatchSizeRows(t *testing.T) {
ch <- &message.WriteInsert{
Record: record,
}
time.Sleep(50 * time.Millisecond)

if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

ch <- &message.WriteInsert{
Record: record,
}
ch <- &message.WriteInsert{ // third message, because we flush before exceeding the limit and then save the third one

waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)

ch <- &message.WriteInsert{
Record: record,
}

waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

ch <- &message.WriteInsert{
Record: record,
}

waitForLength(t, testClient.MessageLen, messageTypeInsert, 4)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)

close(ch)
if err := <-errCh; err != nil {
Expand All @@ -225,7 +235,7 @@ func TestStreamingBatchSizeRows(t *testing.T) {
t.Fatalf("expected 0 open tables, got %d", l)
}

if l := testClient.MessageLen(messageTypeInsert); l != 3 {
if l := testClient.MessageLen(messageTypeInsert); l != 4 {
t.Fatalf("expected 3 insert messages, got %d", l)
}
}
Expand Down Expand Up @@ -253,18 +263,12 @@ func TestStreamingBatchTimeout(t *testing.T) {
ch <- &message.WriteInsert{
Record: record,
}
time.Sleep(50 * time.Millisecond)

if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)

// we need to wait for the batch to be flushed
time.Sleep(time.Millisecond * 50)
time.Sleep(time.Millisecond * 50) // we need to wait for the batch to be flushed

if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)

// flush
tickFn()
Expand Down Expand Up @@ -301,32 +305,35 @@ func TestStreamingBatchNoTimeout(t *testing.T) {
ch <- &message.WriteInsert{
Record: record,
}
time.Sleep(50 * time.Millisecond)

if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

time.Sleep(2 * time.Second)

if l := testClient.MessageLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 insert messages, got %d", l)
}
waitForLength(t, testClient.MessageLen, messageTypeInsert, 0)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

ch <- &message.WriteInsert{
Record: record,
}
waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 0)

ch <- &message.WriteInsert{
Record: record,
}

waitForLength(t, testClient.MessageLen, messageTypeInsert, 2)
waitForLength(t, testClient.InflightLen, messageTypeInsert, 1)

close(ch)
if err := <-errCh; err != nil {
t.Fatal(err)
}

time.Sleep(50 * time.Millisecond)

if l := testClient.OpenLen(messageTypeInsert); l != 0 {
t.Fatalf("expected 0 open tables, got %d", l)
}
Expand Down

0 comments on commit 9caa659

Please sign in to comment.