diff --git a/dm/master/scheduler/scheduler.go b/dm/master/scheduler/scheduler.go
index ff0389b35e..eadedf69c5 100644
--- a/dm/master/scheduler/scheduler.go
+++ b/dm/master/scheduler/scheduler.go
@@ -166,11 +166,18 @@ func NewScheduler(pLogger *log.Logger, securityCfg config.Security) *Scheduler {
 
 // Start starts the scheduler for work.
 // NOTE: for logic errors, it should start without returning errors (but report via metrics or log) so that the user can fix them.
-func (s *Scheduler) Start(pCtx context.Context, etcdCli *clientv3.Client) error {
+func (s *Scheduler) Start(pCtx context.Context, etcdCli *clientv3.Client) (err error) {
 	s.logger.Info("the scheduler is starting")
 
 	s.mu.Lock()
-	defer s.mu.Unlock()
+	defer func() {
+		// Only tear workers down when this start attempt itself failed;
+		// ErrSchedulerStarted means a live scheduler still owns them.
+		if err != nil && !s.started {
+			s.CloseAllWorkers()
+		}
+		s.mu.Unlock()
+	}()
 
 	if s.started {
 		return terror.ErrSchedulerStarted.Generate()
@@ -180,7 +187,7 @@ func (s *Scheduler) Start(pCtx context.Context, etcdCli *clientv3.Client) error
 	s.reset() // reset previous status.
 
 	// recover previous status from etcd.
-	err := s.recoverSources(etcdCli)
+	err = s.recoverSources(etcdCli)
 	if err != nil {
 		return err
 	}
@@ -192,7 +199,8 @@ func (s *Scheduler) Start(pCtx context.Context, etcdCli *clientv3.Client) error
 	if err != nil {
 		return err
 	}
-	rev, err := s.recoverWorkersBounds(etcdCli)
+	var rev int64
+	rev, err = s.recoverWorkersBounds(etcdCli)
 	if err != nil {
 		return err
 	}
@@ -228,6 +236,7 @@ func (s *Scheduler) Close() {
 		s.cancel()
 		s.cancel = nil
 	}
+	s.CloseAllWorkers()
 	s.mu.Unlock()
 
 	// need to wait for goroutines to return which may hold the mutex.
@@ -239,6 +248,14 @@ func (s *Scheduler) Close() {
 	s.logger.Info("the scheduler has closed")
 }
 
+// CloseAllWorkers closes all the scheduler's workers.
+// It does not take s.mu itself; Start and Close both call it with s.mu held.
+func (s *Scheduler) CloseAllWorkers() {
+	for _, worker := range s.workers {
+		worker.Close()
+	}
+}
+
 // AddSourceCfg adds the upstream source config to the cluster.
 // NOTE: please verify the config before call this.
 func (s *Scheduler) AddSourceCfg(cfg config.SourceConfig) error {
@@ -1249,6 +1266,10 @@ func (s *Scheduler) recoverWorkersBounds(cli *clientv3.Client) (int64, error) {
 		}
 	}
 
+	failpoint.Inject("failToRecoverWorkersBounds", func(_ failpoint.Value) {
+		log.L().Info("mock failure", zap.String("failpoint", "failToRecoverWorkersBounds"))
+		failpoint.Return(0, errors.New("failToRecoverWorkersBounds"))
+	})
 	// 5. delete invalid source bound info in etcd
 	if len(sbm) > 0 {
 		invalidSourceBounds := make([]string, 0, len(sbm))
diff --git a/dm/master/scheduler/scheduler_test.go b/dm/master/scheduler/scheduler_test.go
index 55973c757b..75581797d4 100644
--- a/dm/master/scheduler/scheduler_test.go
+++ b/dm/master/scheduler/scheduler_test.go
@@ -15,6 +15,7 @@ package scheduler
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"testing"
 	"time"
@@ -26,6 +27,7 @@ import (
 	"go.etcd.io/etcd/integration"
 
 	"github.com/pingcap/dm/dm/config"
+	"github.com/pingcap/dm/dm/master/workerrpc"
 	"github.com/pingcap/dm/dm/pb"
 	"github.com/pingcap/dm/pkg/ha"
 	"github.com/pingcap/dm/pkg/log"
@@ -1224,3 +1226,47 @@ func (t *testScheduler) TestStartStopSource(c *C) {
 	c.Assert(err, IsNil)
 	c.Assert(workers, HasLen, 0)
 }
+
+func checkAllWorkersClosed(c *C, s *Scheduler, closed bool) {
+	for _, worker := range s.workers {
+		cli, ok := worker.cli.(*workerrpc.GRPCClient)
+		c.Assert(ok, IsTrue)
+		c.Assert(cli.Closed(), Equals, closed)
+	}
+}
+
+func (t *testScheduler) TestCloseAllWorkers(c *C) {
+	var (
+		logger = log.L()
+		s      = NewScheduler(&logger, config.Security{})
+		names  []string
+	)
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	for i := 1; i < 4; i++ {
+		names = append(names, fmt.Sprintf("worker%d", i))
+	}
+
+	for i, name := range names {
+		info := ha.NewWorkerInfo(name, fmt.Sprintf("127.0.0.1:%d", 50801+i))
+		_, err := ha.PutWorkerInfo(etcdTestCli, info)
+		c.Assert(err, IsNil)
+	}
+
+	c.Assert(failpoint.Enable("github.com/pingcap/dm/dm/master/scheduler/failToRecoverWorkersBounds", "return"), IsNil)
+	// Test closed when fail to start
+	c.Assert(s.Start(ctx, etcdTestCli), ErrorMatches, "failToRecoverWorkersBounds")
+	c.Assert(s.workers, HasLen, 3)
+	checkAllWorkersClosed(c, s, true)
+	c.Assert(failpoint.Disable("github.com/pingcap/dm/dm/master/scheduler/failToRecoverWorkersBounds"), IsNil)
+
+	s.workers = map[string]*Worker{}
+	c.Assert(s.Start(ctx, etcdTestCli), IsNil)
+	checkAllWorkersClosed(c, s, false)
+	// A redundant Start must fail without closing the running scheduler's workers.
+	c.Assert(s.Start(ctx, etcdTestCli), NotNil)
+	checkAllWorkersClosed(c, s, false)
+	s.Close()
+	c.Assert(s.workers, HasLen, 3)
+	checkAllWorkersClosed(c, s, true)
+}
diff --git a/dm/master/server.go b/dm/master/server.go
index c04e17c755..8e400d1d43 100644
--- a/dm/master/server.go
+++ b/dm/master/server.go
@@ -1883,10 +1883,10 @@ func (s *Server) OperateSchema(ctx context.Context, req *pb.OperateSchemaRequest
 	}, nil
 }
 
-func (s *Server) createMasterClientByName(ctx context.Context, name string) (pb.MasterClient, error) {
+func (s *Server) createMasterClientByName(ctx context.Context, name string) (pb.MasterClient, *grpc.ClientConn, error) {
 	listResp, err := s.etcdClient.MemberList(ctx)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 	clientURLs := []string{}
 	for _, m := range listResp.Members {
@@ -1898,11 +1898,11 @@ func (s *Server) createMasterClientByName(ctx context.Context, name string) (pb.
 		}
 	}
 	if len(clientURLs) == 0 {
-		return nil, errors.New("master not found")
+		return nil, nil, errors.New("master not found")
 	}
 	tls, err := toolutils.NewTLS(s.cfg.SSLCA, s.cfg.SSLCert, s.cfg.SSLKey, s.cfg.AdvertiseAddr, s.cfg.CertAllowedCN)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 
 	var conn *grpc.ClientConn
@@ -1911,12 +1911,12 @@ func (s *Server) createMasterClientByName(ctx context.Context, name string) (pb.
 		conn, err = grpc.Dial(clientURL, tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second))
 		if err == nil {
 			masterClient := pb.NewMasterClient(conn)
-			return masterClient, nil
+			return masterClient, conn, nil
 		}
 		log.L().Error("can not dial to master", zap.String("name", name), zap.String("client url", clientURL), log.ShortError(err))
 	}
 	// return last err
-	return nil, err
+	return nil, nil, err
 }
 
 // GetMasterCfg implements MasterServer.GetMasterCfg.
@@ -1971,12 +1971,13 @@ func (s *Server) GetCfg(ctx context.Context, req *pb.GetCfgRequest) (*pb.GetCfgR
 			return resp2, nil
 		}
 
-		masterClient, err := s.createMasterClientByName(ctx, req.Name)
+		masterClient, grpcConn, err := s.createMasterClientByName(ctx, req.Name)
 		if err != nil {
 			resp2.Msg = err.Error()
 			// nolint:nilerr
 			return resp2, nil
 		}
+		defer grpcConn.Close()
 		masterResp, err := masterClient.GetMasterCfg(ctx, &pb.GetMasterCfgRequest{})
 		if err != nil {
 			resp2.Msg = err.Error()
diff --git a/dm/master/workerrpc/rawgrpc.go b/dm/master/workerrpc/rawgrpc.go
index 6e704b5676..55a14fec7e 100644
--- a/dm/master/workerrpc/rawgrpc.go
+++ b/dm/master/workerrpc/rawgrpc.go
@@ -96,6 +96,11 @@ func (c *GRPCClient) Close() error {
 	return nil
 }
 
+// Closed reports whether this gRPC client has been closed; it is currently only used by tests.
+func (c *GRPCClient) Closed() bool {
+	return c.closed.Load()
+}
+
 func callRPC(ctx context.Context, client pb.WorkerClient, req *Request) (*Response, error) {
 	resp := &Response{}
 	resp.Type = req.Type