diff --git a/br/pkg/streamhelper/BUILD.bazel b/br/pkg/streamhelper/BUILD.bazel
index a2c0080287e1b..0a58ed8eb5b40 100644
--- a/br/pkg/streamhelper/BUILD.bazel
+++ b/br/pkg/streamhelper/BUILD.bazel
@@ -68,7 +68,7 @@ go_test(
     ],
     flaky = True,
     race = "on",
-    shard_count = 26,
+    shard_count = 27,
     deps = [
         ":streamhelper",
         "//br/pkg/errors",
diff --git a/br/pkg/streamhelper/advancer.go b/br/pkg/streamhelper/advancer.go
index c9b23c207e41f..ffcf4ed0499db 100644
--- a/br/pkg/streamhelper/advancer.go
+++ b/br/pkg/streamhelper/advancer.go
@@ -697,3 +697,12 @@ func (c *CheckpointAdvancer) asyncResolveLocksForRanges(ctx context.Context, tar
 		c.inResolvingLock.Store(false)
 	}()
 }
+
+func (c *CheckpointAdvancer) TEST_registerCallbackForSubscriptions(f func()) int {
+	cnt := 0
+	for _, sub := range c.subscriber.subscriptions {
+		sub.onDaemonExit = f
+		cnt += 1
+	}
+	return cnt
+}
diff --git a/br/pkg/streamhelper/advancer_test.go b/br/pkg/streamhelper/advancer_test.go
index 55abd5baa2c80..d3051a6ba57d3 100644
--- a/br/pkg/streamhelper/advancer_test.go
+++ b/br/pkg/streamhelper/advancer_test.go
@@ -548,3 +548,55 @@ func TestUnregisterAfterPause(t *testing.T) {
 		return err != nil && strings.Contains(err.Error(), "check point lagged too large")
 	}, 5*time.Second, 300*time.Millisecond)
 }
+
+func TestOwnershipLost(t *testing.T) {
+	c := createFakeCluster(t, 4, false)
+	c.splitAndScatter(manyRegions(0, 10240)...)
+	installSubscribeSupport(c)
+	ctx, cancel := context.WithCancel(context.Background())
+	env := &testEnv{fakeCluster: c, testCtx: t}
+	adv := streamhelper.NewCheckpointAdvancer(env)
+	adv.OnStart(ctx)
+	adv.OnBecomeOwner(ctx)
+	require.NoError(t, adv.OnTick(ctx))
+	c.advanceCheckpoints()
+	c.flushAll()
+	failpoint.Enable("github.com/pingcap/tidb/br/pkg/streamhelper/subscription.listenOver.aboutToSend", "pause")
+	failpoint.Enable("github.com/pingcap/tidb/br/pkg/streamhelper/FlushSubscriber.Clear.timeoutMs", "return(500)")
+	wg := new(sync.WaitGroup)
+	wg.Add(adv.TEST_registerCallbackForSubscriptions(wg.Done))
+	cancel()
+	failpoint.Disable("github.com/pingcap/tidb/br/pkg/streamhelper/subscription.listenOver.aboutToSend")
+	wg.Wait()
+}
+
+func TestSubscriptionPanic(t *testing.T) {
+	c := createFakeCluster(t, 4, false)
+	c.splitAndScatter(manyRegions(0, 20)...)
+	installSubscribeSupport(c)
+	ctx, cancel := context.WithCancel(context.Background())
+	env := &testEnv{fakeCluster: c, testCtx: t}
+	adv := streamhelper.NewCheckpointAdvancer(env)
+	adv.OnStart(ctx)
+	adv.OnBecomeOwner(ctx)
+	wg := new(sync.WaitGroup)
+	wg.Add(adv.TEST_registerCallbackForSubscriptions(wg.Done))
+
+	require.NoError(t, adv.OnTick(ctx))
+	failpoint.Enable("github.com/pingcap/tidb/br/pkg/streamhelper/subscription.listenOver.aboutToSend", "5*panic")
+	ckpt := c.advanceCheckpoints()
+	c.flushAll()
+	cnt := 0
+	for {
+		require.NoError(t, adv.OnTick(ctx))
+		cnt++
+		if env.checkpoint >= ckpt {
+			break
+		}
+		if cnt > 100 {
+			t.Fatalf("After 100 times, the progress cannot be advanced.")
+		}
+	}
+	cancel()
+	wg.Wait()
+}
diff --git a/br/pkg/streamhelper/flush_subscriber.go b/br/pkg/streamhelper/flush_subscriber.go
index 1a3f0a523d170..a73673b6f887c 100644
--- a/br/pkg/streamhelper/flush_subscriber.go
+++ b/br/pkg/streamhelper/flush_subscriber.go
@@ -11,8 +11,10 @@ import (
 
 	"github.com/google/uuid"
 	"github.com/pingcap/errors"
+	"github.com/pingcap/failpoint"
 	logbackup "github.com/pingcap/kvproto/pkg/logbackuppb"
 	"github.com/pingcap/log"
+	berrors "github.com/pingcap/tidb/br/pkg/errors"
 	"github.com/pingcap/tidb/br/pkg/logutil"
 	"github.com/pingcap/tidb/br/pkg/streamhelper/spans"
 	"github.com/pingcap/tidb/metrics"
@@ -23,6 +25,11 @@ import (
 	"google.golang.org/grpc/status"
 )
 
+const (
+	// clearSubscriberTimeOut is the timeout for clearing the subscriber.
+	clearSubscriberTimeOut = 1 * time.Minute
+)
+
 // FlushSubscriber maintains the state of subscribing to the cluster.
 type FlushSubscriber struct {
 	dialer LogBackupService
@@ -86,7 +93,7 @@ func (f *FlushSubscriber) UpdateStoreTopology(ctx context.Context) error {
 	for id := range f.subscriptions {
 		_, ok := storeSet[id]
 		if !ok {
-			f.removeSubscription(id)
+			f.removeSubscription(ctx, id)
 		}
 	}
 	return nil
@@ -94,9 +101,18 @@ func (f *FlushSubscriber) UpdateStoreTopology(ctx context.Context) error {
 
 // Clear clears all the subscriptions.
 func (f *FlushSubscriber) Clear() {
-	log.Info("[log backup flush subscriber] Clearing.")
+	timeout := clearSubscriberTimeOut
+	failpoint.Inject("FlushSubscriber.Clear.timeoutMs", func(v failpoint.Value) {
+		//nolint:durationcheck
+		timeout = time.Duration(v.(int)) * time.Millisecond
+	})
+	log.Info("Clearing.",
+		zap.String("category", "log backup flush subscriber"),
+		zap.Duration("timeout", timeout))
+	cx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
 	for id := range f.subscriptions {
-		f.removeSubscription(id)
+		f.removeSubscription(cx, id)
 	}
 }
 
@@ -132,15 +148,11 @@ type eventStream = logbackup.LogBackup_SubscribeFlushEventClient
 
 type joinHandle <-chan struct{}
 
-func (jh joinHandle) WaitTimeOut(dur time.Duration) {
-	var t <-chan time.Time
-	if dur > 0 {
-		t = time.After(dur)
-	}
+func (jh joinHandle) Wait(ctx context.Context) {
 	select {
 	case <-jh:
-	case <-t:
-		log.Warn("join handle timed out.")
+	case <-ctx.Done():
+		log.Warn("join handle timed out.", zap.StackSkip("caller", 1))
 	}
 }
 
@@ -171,6 +183,8 @@ type subscription struct {
 	// we need to try reconnect even there is a error cannot be retry.
 	storeBootAt uint64
 	output      chan<- spans.Valued
+
+	onDaemonExit func()
 }
 
 func (s *subscription) emitError(err error) {
@@ -213,7 +227,7 @@ func (s *subscription) doConnect(ctx context.Context, dialer LogBackupService) e
 	log.Info("[log backup subscription manager] Adding subscription.", zap.Uint64("store", s.storeID), zap.Uint64("boot", s.storeBootAt))
 	// We should shutdown the background task firstly.
 	// Once it yields some error during shuting down, the error won't be brought to next run.
-	s.close()
+	s.close(ctx)
 	s.clearError()
 
 	c, err := dialer.GetLogBackupClient(ctx, s.storeID)
@@ -236,10 +250,10 @@ func (s *subscription) doConnect(ctx context.Context, dialer LogBackupService) e
 	return nil
 }
-func (s *subscription) close() {
+func (s *subscription) close(ctx context.Context) {
 	if s.cancel != nil {
 		s.cancel()
-		s.background.WaitTimeOut(1 * time.Minute)
+		s.background.Wait(ctx)
 	}
 	// HACK: don't close the internal channel here,
 	// because it is a ever-sharing channel.
 }
@@ -248,6 +262,16 @@ func (s *subscription) listenOver(ctx context.Context, cli eventStream) {
 	storeID := s.storeID
 	logutil.CL(ctx).Info("Listen starting.", zap.Uint64("store", storeID))
+	defer func() {
+		if s.onDaemonExit != nil {
+			s.onDaemonExit()
+		}
+
+		if pData := recover(); pData != nil {
+			log.Warn("Subscriber paniked.", zap.Uint64("store", storeID), zap.Any("panic-data", pData), zap.Stack("stack"))
+			s.emitError(errors.Annotatef(berrors.ErrUnknown, "panic during executing: %v", pData))
+		}
+	}()
 
 	for {
 		// Shall we use RecvMsg for better performance?
 		// Note that the spans.Full requires the input slice be immutable.
@@ -262,6 +286,7 @@ func (s *subscription) listenOver(ctx context.Context, cli eventStream) {
 			return
 		}
 
+		log.Debug("Sending events.", zap.Int("size", len(msg.Events)))
 		for _, m := range msg.Events {
 			start, err := decodeKey(m.StartKey)
 			if err != nil {
@@ -275,13 +300,22 @@ func (s *subscription) listenOver(ctx context.Context, cli eventStream) {
 					logutil.Key("event", m.EndKey), logutil.ShortError(err))
 				continue
 			}
-			s.output <- spans.Valued{
+			failpoint.Inject("subscription.listenOver.aboutToSend", func() {})
+
+			evt := spans.Valued{
 				Key: spans.Span{
 					StartKey: start,
 					EndKey:   end,
 				},
 				Value: m.Checkpoint,
 			}
+			select {
+			case s.output <- evt:
+			case <-ctx.Done():
+				logutil.CL(ctx).Warn("Context canceled while sending events.",
+					zap.Uint64("store", storeID))
+				return
+			}
 		}
 		metrics.RegionCheckpointSubscriptionEvent.WithLabelValues(strconv.Itoa(int(storeID))).Add(float64(len(msg.Events)))
 	}
@@ -291,11 +325,12 @@ func (f *FlushSubscriber) addSubscription(ctx context.Context, toStore Store) {
 	f.subscriptions[toStore.ID] = newSubscription(toStore, f.eventsTunnel)
 }
 
-func (f *FlushSubscriber) removeSubscription(toStore uint64) {
+func (f *FlushSubscriber) removeSubscription(ctx context.Context, toStore uint64) {
 	subs, ok := f.subscriptions[toStore]
 	if ok {
-		log.Info("[log backup subscription manager] Removing subscription.", zap.Uint64("store", toStore))
-		subs.close()
+		log.Info("Removing subscription.", zap.String("category", "log backup subscription manager"),
+			zap.Uint64("store", toStore))
+		subs.close(ctx)
 		delete(f.subscriptions, toStore)
 	}
 }
diff --git a/br/pkg/streamhelper/subscription_test.go b/br/pkg/streamhelper/subscription_test.go
index 2341cb05dc01e..da7aa627eabd0 100644
--- a/br/pkg/streamhelper/subscription_test.go
+++ b/br/pkg/streamhelper/subscription_test.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"sync"
 	"testing"
+	"time"
 
 	"github.com/pingcap/tidb/br/pkg/streamhelper"
 	"github.com/pingcap/tidb/br/pkg/streamhelper/spans"
@@ -32,6 +33,16 @@ func installSubscribeSupportForRandomN(c *fakeCluster, n int) {
 	}
 }
 
+func waitPendingEvents(t *testing.T, sub *streamhelper.FlushSubscriber) {
+	last := len(sub.Events())
+	time.Sleep(100 * time.Microsecond)
+	require.Eventually(t, func() bool {
+		noProg := len(sub.Events()) == last
+		last = len(sub.Events())
+		return noProg
+	}, 3*time.Second, 100*time.Millisecond)
+}
+
 func TestSubBasic(t *testing.T) {
 	req := require.New(t)
 	ctx := context.Background()
@@ -47,6 +58,7 @@ func TestSubBasic(t *testing.T) {
 	}
 	sub.HandleErrors(ctx)
 	req.NoError(sub.PendingErrors())
+	waitPendingEvents(t, sub)
 	sub.Drop()
 	s := spans.Sorted(spans.NewFullWith(spans.Full(), 1))
 	for k := range sub.Events() {
@@ -81,6 +93,7 @@ func TestNormalError(t *testing.T) {
 		cp = c.advanceCheckpoints()
 		c.flushAll()
 	}
+	waitPendingEvents(t, sub)
 	sub.Drop()
 	s := spans.Sorted(spans.NewFullWith(spans.Full(), 1))
 	for k := range sub.Events() {
@@ -155,6 +168,7 @@ func TestStoreRemoved(t *testing.T) {
 
 	sub.HandleErrors(ctx)
 	req.NoError(sub.PendingErrors())
+	waitPendingEvents(t, sub)
 	sub.Drop()
 	s := spans.Sorted(spans.NewFullWith(spans.Full(), 1))
 	for k := range sub.Events() {
@@ -188,6 +202,8 @@ func TestSomeOfStoreUnsupported(t *testing.T) {
 	}
 	s := spans.Sorted(spans.NewFullWith(spans.Full(), 1))
 	m := new(sync.Mutex)
+
+	waitPendingEvents(t, sub)
 	sub.Drop()
 	for k := range sub.Events() {
 		s.Merge(k)