From 34cb745c092af15c3b6911fe84f845ce52d7fcdb Mon Sep 17 00:00:00 2001 From: huabing zhao Date: Wed, 13 Sep 2023 14:17:45 +0800 Subject: [PATCH] purge response channel before processing a request to avoid deadlock Signed-off-by: huabing zhao --- pkg/server/delta/v3/server.go | 71 +++++++++++++++++++---------- pkg/server/delta/v3/watches.go | 11 +++-- pkg/server/delta/v3/watches_test.go | 5 -- 3 files changed, 56 insertions(+), 31 deletions(-) diff --git a/pkg/server/delta/v3/server.go b/pkg/server/delta/v3/server.go index 74d8af57b4..65fb0d3f75 100644 --- a/pkg/server/delta/v3/server.go +++ b/pkg/server/delta/v3/server.go @@ -65,7 +65,7 @@ func NewServer(ctx context.Context, config cache.ConfigWatcher, callbacks Callba return s } -func (s *server) processDelta(str stream.DeltaStream, reqCh chan *discovery.DeltaDiscoveryRequest, defaultTypeURL string) error { +func (s *server) processDelta(str stream.DeltaStream, reqCh <-chan *discovery.DeltaDiscoveryRequest, defaultTypeURL string) error { streamID := atomic.AddInt64(&s.streamCount, 1) // streamNonce holds a unique nonce for req-resp pairs per xDS stream. @@ -83,7 +83,7 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh chan *discovery.Delt } }() - // Sends a response, returns the new stream nonce + // sends a response, returns the new stream nonce send := func(resp cache.DeltaResponse) (string, error) { if resp == nil { return "", errors.New("missing response") @@ -103,6 +103,44 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh chan *discovery.Delt return response.Nonce, str.Send(response) } + // process a single delta response + process := func(resp cache.DeltaResponse) error { + typ := resp.GetDeltaRequest().GetTypeUrl() + if resp == deltaErrorResponse { + return status.Errorf(codes.Unavailable, typ+" watch failed") + } + + nonce, err := send(resp) + if err != nil { + return err + } + + watch := watches.deltaWatches[typ] + watch.nonce = nonce + + watch.state.SetResourceVersions(resp.GetNextVersionMap()) + watches.deltaWatches[typ] = watch + return nil + } + + // processAll purges the deltaMuxedResponses channel + processAll := func() error { + for { + select { + // We watch the multiplexed channel for incoming responses. + case resp, more := <-watches.deltaMuxedResponses: + if !more { + break + } + if err := process(resp); err != nil { + return err + } + default: + return nil + } + } + } + if s.callbacks != nil { if err := s.callbacks.OnDeltaStreamOpen(str.Context(), streamID, defaultTypeURL); err != nil { return err @@ -113,41 +151,29 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh chan *discovery.Delt select { case <-s.ctx.Done(): return nil + // We watch the multiplexed channel for incoming responses. case resp, more := <-watches.deltaMuxedResponses: + // input stream ended or errored out if !more { break } - typ := resp.GetDeltaRequest().GetTypeUrl() - if resp == deltaErrorResponse { - return status.Errorf(codes.Unavailable, typ+" watch failed") - } - - nonce, err := send(resp) - if err != nil { + if err := process(resp); err != nil { return err } - - watch := watches.deltaWatches[typ] - watch.nonce = nonce - - watch.state.SetResourceVersions(resp.GetNextVersionMap()) - watches.deltaWatches[typ] = watch case req, more := <-reqCh: // input stream ended or errored out if !more { return nil } + if req == nil { return status.Errorf(codes.Unavailable, "empty request") } - // make sure responses are processed prior to new requests to avoid deadlock - if len(watches.deltaMuxedResponses) > 0 { - go func() { - reqCh <- req - }() - break + // make sure all existing responses are processed prior to new requests to avoid deadlock + if err := processAll(); err != nil { + return err } if s.callbacks != nil { @@ -192,8 +218,7 @@ func (s *server) processDelta(str stream.DeltaStream, reqCh chan *discovery.Delt s.subscribe(req.GetResourceNamesSubscribe(), &watch.state) s.unsubscribe(req.GetResourceNamesUnsubscribe(), &watch.state) - watch.responses = watches.deltaMuxedResponses - watch.cancel = s.cache.CreateDeltaWatch(req, watch.state, watch.responses) + watch.cancel = s.cache.CreateDeltaWatch(req, watch.state, watches.deltaMuxedResponses) watches.deltaWatches[typeURL] = watch } } diff --git a/pkg/server/delta/v3/watches.go b/pkg/server/delta/v3/watches.go index 839d323211..63c4c2d38d 100644 --- a/pkg/server/delta/v3/watches.go +++ b/pkg/server/delta/v3/watches.go @@ -17,6 +17,10 @@ type watches struct { // newWatches creates and initializes watches. func newWatches() watches { // deltaMuxedResponses needs a buffer to release go-routines populating it + // + // because deltaMuxedResponses can be populated by an update from the cache + // and a request from the client, we need to create the channel with a buffer + // size of 2x the number of types to avoid deadlocks. return watches{ deltaWatches: make(map[string]watch, int(types.UnknownType)), deltaMuxedResponses: make(chan cache.DeltaResponse, int(types.UnknownType)*2), @@ -28,13 +32,14 @@ func (w *watches) Cancel() { for _, watch := range w.deltaWatches { watch.Cancel() } + + close(w.deltaMuxedResponses) } // watch contains the necessary modifiables for receiving resource responses type watch struct { - responses chan cache.DeltaResponse - cancel func() - nonce string + cancel func() + nonce string state stream.StreamState } diff --git a/pkg/server/delta/v3/watches_test.go b/pkg/server/delta/v3/watches_test.go index 6113498707..cee0985ebd 100644 --- a/pkg/server/delta/v3/watches_test.go +++ b/pkg/server/delta/v3/watches_test.go @@ -5,8 +5,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - - "github.com/envoyproxy/go-control-plane/pkg/cache/v3" ) func TestDeltaWatches(t *testing.T) { @@ -14,14 +12,11 @@ func TestDeltaWatches(t *testing.T) { watches := newWatches() cancelCount := 0 - var channels []chan cache.DeltaResponse // create a few watches, and ensure that the cancel function are called and the channels are closed for i := 0; i < 5; i++ { newWatch := watch{} if i%2 == 0 { newWatch.cancel = func() { cancelCount++ } - newWatch.responses = make(chan cache.DeltaResponse) - channels = append(channels, newWatch.responses) } watches.deltaWatches[strconv.Itoa(i)] = newWatch