Skip to content

Commit

Permalink
close the connection to avoid retrying to connect
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Leung <rleungx@gmail.com>
  • Loading branch information
rleungx committed Jan 25, 2024
1 parent 9a82b47 commit 03f0fcd
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 24 deletions.
3 changes: 3 additions & 0 deletions pkg/utils/grpcutil/grpcutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ func CheckStream(ctx context.Context, cancel context.CancelFunc, done chan struc

// NeedRebuildConnection checks if the error is a connection error.
func NeedRebuildConnection(err error) bool {
if err == nil {
return false
}
return err == io.EOF ||
strings.Contains(err.Error(), codes.Unavailable.String()) || // Unavailable indicates the service is currently unavailable. This is a most likely a transient condition.
strings.Contains(err.Error(), codes.DeadlineExceeded.String()) || // DeadlineExceeded means operation expired before completion.
Expand Down
38 changes: 17 additions & 21 deletions server/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,17 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
forwardStream tsopb.TSO_TsoClient
forwardCtx context.Context
cancelForward context.CancelFunc
tsoStreamErr error
lastForwardedHost string
)
defer func() {
s.concurrentTSOProxyStreamings.Add(-1)
if cancelForward != nil {
cancelForward()
}
if grpcutil.NeedRebuildConnection(tsoStreamErr) {
s.closeDelegateClient(lastForwardedHost)
}
}()

maxConcurrentTSOProxyStreamings := int32(s.GetMaxConcurrentTSOProxyStreamings())
Expand Down Expand Up @@ -131,26 +135,31 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {

forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), utils.TSOServiceName)
if !ok || len(forwardedHost) == 0 {
return errors.WithStack(ErrNotFoundTSOAddr)
tsoStreamErr = errors.WithStack(ErrNotFoundTSOAddr)
return tsoStreamErr
}
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancelForward != nil {
cancelForward()
}

s.closeDelegateClient(lastForwardedHost)
clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
if err != nil {
return errors.WithStack(err)
tsoStreamErr = errors.WithStack(err)
return tsoStreamErr

Check warning on line 150 in server/forward.go

View check run for this annotation

Codecov / codecov/patch

server/forward.go#L149-L150

Added lines #L149 - L150 were not covered by tests
}
forwardStream, forwardCtx, cancelForward, err = s.createTSOForwardStream(stream.Context(), clientConn)
if err != nil {
return errors.WithStack(err)
tsoStreamErr = errors.WithStack(err)
return tsoStreamErr

Check warning on line 155 in server/forward.go

View check run for this annotation

Codecov / codecov/patch

server/forward.go#L154-L155

Added lines #L154 - L155 were not covered by tests
}
lastForwardedHost = forwardedHost
}

tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh)
if err != nil {
tsoStreamErr = errors.WithStack(err)
return errors.WithStack(err)
}

Expand Down Expand Up @@ -363,25 +372,12 @@ func (s *GrpcServer) getDelegateClient(ctx context.Context, forwardedHost string
return conn.(*grpc.ClientConn), nil
}

func (s *GrpcServer) getForwardedHost(ctx, streamCtx context.Context, serviceName ...string) (forwardedHost string, err error) {
if s.IsAPIServiceMode() {
var ok bool
if len(serviceName) == 0 {
return "", ErrNotFoundService
}
forwardedHost, ok = s.GetServicePrimaryAddr(ctx, serviceName[0])
if !ok || len(forwardedHost) == 0 {
switch serviceName[0] {
case utils.TSOServiceName:
return "", ErrNotFoundTSOAddr
case utils.SchedulingServiceName:
return "", ErrNotFoundSchedulingAddr
}
}
} else if fh := grpcutil.GetForwardedHost(streamCtx); !s.isLocalRequest(fh) {
forwardedHost = fh
func (s *GrpcServer) closeDelegateClient(forwardedHost string) {
client, ok := s.clientConns.LoadAndDelete(forwardedHost)
if !ok {
return
}
return forwardedHost, nil
client.(*grpc.ClientConn).Close()
}

func (s *GrpcServer) isLocalRequest(host string) bool {
Expand Down
10 changes: 7 additions & 3 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,8 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
return errors.WithStack(err)
}

if forwardedHost, err := s.getForwardedHost(ctx, stream.Context(), utils.TSOServiceName); err != nil {
return err
} else if len(forwardedHost) > 0 {
forwardedHost := grpcutil.GetForwardedHost(stream.Context())
if !s.isLocalRequest(forwardedHost) {
clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
if err != nil {
return errors.WithStack(err)
Expand Down Expand Up @@ -1332,6 +1331,8 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error
if cancel != nil {
cancel()
}

s.closeDelegateClient(lastForwardedSchedulingHost)
client, err := s.getDelegateClient(s.ctx, forwardedSchedulingHost)
if err != nil {
errRegionHeartbeatClient.Inc()
Expand Down Expand Up @@ -1370,6 +1371,9 @@ func (s *GrpcServer) RegionHeartbeat(stream pdpb.PD_RegionHeartbeatServer) error
}
if err := forwardSchedulingStream.Send(schedulingpbReq); err != nil {
forwardSchedulingStream = nil
if grpcutil.NeedRebuildConnection(err) {
s.closeDelegateClient(lastForwardedSchedulingHost)

Check warning on line 1375 in server/grpc_service.go

View check run for this annotation

Codecov / codecov/patch

server/grpc_service.go#L1374-L1375

Added lines #L1374 - L1375 were not covered by tests
}
errRegionHeartbeatSend.Inc()
log.Error("failed to send request to scheduling service", zap.Error(err))
}
Expand Down

0 comments on commit 03f0fcd

Please sign in to comment.