Skip to content

Commit

Permalink
balancer: fix logic to prevent producer streams before READY is repor…
Browse files Browse the repository at this point in the history
…ted (#7651)
  • Loading branch information
dfawley authored Sep 20, 2024
1 parent 6c48e47 commit 8ea3460
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 14 deletions.
7 changes: 3 additions & 4 deletions balancer_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ type acBalancerWrapper struct {

// updateState is invoked by grpc to push a subConn state update to the
// underlying balancer.
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) {
func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error, readyChan chan struct{}) {
acbw.ccb.serializer.TrySchedule(func(ctx context.Context) {
if ctx.Err() != nil || acbw.ccb.balancer == nil {
return
Expand All @@ -278,12 +278,11 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve
acbw.ac.mu.Lock()
defer acbw.ac.mu.Unlock()
if s == connectivity.Ready {
// When changing states to READY, reset stateReadyChan. Wait until
// When changing states to READY, close stateReadyChan. Wait until
// after we notify the LB policy's listener(s) in order to prevent
// ac.getTransport() from unblocking before the LB policy starts
// tracking the subchannel as READY.
close(acbw.ac.stateReadyChan)
acbw.ac.stateReadyChan = make(chan struct{})
close(readyChan)
}
})
}
Expand Down
22 changes: 16 additions & 6 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1193,14 +1193,22 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error)
if ac.state == s {
return
}
if ac.state == connectivity.Ready {
// When leaving ready, re-create the ready channel.
ac.stateReadyChan = make(chan struct{})
}
if s == connectivity.Shutdown {
// Wake any producer waiting to create a stream on the transport.
close(ac.stateReadyChan)
}
ac.state = s
ac.channelz.ChannelMetrics.State.Store(&s)
if lastErr == nil {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v", s)
} else {
channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr)
}
ac.acbw.updateState(s, ac.curAddr, lastErr)
ac.acbw.updateState(s, ac.curAddr, lastErr, ac.stateReadyChan)
}

// adjustParams updates parameters used to create transports upon
Expand Down Expand Up @@ -1510,18 +1518,20 @@ func (ac *addrConn) getReadyTransport() transport.ClientTransport {
func (ac *addrConn) getTransport(ctx context.Context) (transport.ClientTransport, error) {
for ctx.Err() == nil {
ac.mu.Lock()
t, state, sc := ac.transport, ac.state, ac.stateReadyChan
t, state, readyChan := ac.transport, ac.state, ac.stateReadyChan
ac.mu.Unlock()
if state == connectivity.Ready {
return t, nil
}
if state == connectivity.Shutdown {
// Return an error immediately in only this case since a connection
// will never occur.
return nil, status.Errorf(codes.Unavailable, "SubConn shutting down")
}

select {
case <-ctx.Done():
case <-sc:
case <-readyChan:
if state == connectivity.Ready {
return t, nil
}
}
}
return nil, status.FromContextError(ctx.Err()).Err()
Expand Down
20 changes: 16 additions & 4 deletions producer_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (*producerBuilder) Build(cci any) (balancer.Producer, func()) {
}
}

func (p *producer) TestStreamStart(t *testing.T, streamStarted chan<- struct{}) {
func (p *producer) testStreamStart(t *testing.T, streamStarted chan<- struct{}) {
go func() {
defer close(p.stopped)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
Expand All @@ -68,8 +68,11 @@ var producerBuilderSingleton = &producerBuilder{}
// TestProducerStreamStartsAfterReady ensures producer streams only start after
// the subchannel reports as READY to the LB policy.
func (s) TestProducerStreamStartsAfterReady(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
name := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "")
producerCh := make(chan balancer.Producer)
var producerClose func()
streamStarted := make(chan struct{})
done := make(chan struct{})
bf := stub.BalancerFuncs{
Expand All @@ -90,7 +93,8 @@ func (s) TestProducerStreamStartsAfterReady(t *testing.T) {
if err != nil {
return err
}
producer, _ := sc.GetOrBuildProducer(producerBuilderSingleton)
var producer balancer.Producer
producer, producerClose = sc.GetOrBuildProducer(producerBuilderSingleton)
producerCh <- producer
sc.Connect()
return nil
Expand Down Expand Up @@ -119,7 +123,15 @@ func (s) TestProducerStreamStartsAfterReady(t *testing.T) {

go cc.Connect()
p := <-producerCh
p.(*producer).TestStreamStart(t, streamStarted)
p.(*producer).testStreamStart(t, streamStarted)

<-done
select {
case <-done:
// Wait for the stream to start before exiting; otherwise the ClientConn
// will close and cause stream creation to fail.
<-streamStarted
producerClose()
case <-ctx.Done():
t.Error("Timed out waiting for test to complete")
}
}

0 comments on commit 8ea3460

Please sign in to comment.