diff --git a/pkg/registry/common/memory/ns_server.go b/pkg/registry/common/memory/ns_server.go index 1424eeeff..f57e5790c 100644 --- a/pkg/registry/common/memory/ns_server.go +++ b/pkg/registry/common/memory/ns_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,113 +18,143 @@ package memory import ( "context" - "errors" "io" + "github.com/edwarnicke/serialize" "github.com/golang/protobuf/ptypes/empty" "github.com/google/uuid" - "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/edwarnicke/serialize" + "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/matchutils" ) -type networkServiceRegistryServer struct { +type memoryNSServer struct { networkServices NetworkServiceSyncMap executor serialize.Executor eventChannels map[string]chan *registry.NetworkService eventChannelSize int } -func (n *networkServiceRegistryServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) { +// NewNetworkServiceRegistryServer creates new memory based NetworkServiceRegistryServer +func NewNetworkServiceRegistryServer(options ...Option) registry.NetworkServiceRegistryServer { + s := &memoryNSServer{ + eventChannelSize: defaultEventChannelSize, + eventChannels: make(map[string]chan *registry.NetworkService), + } + for _, o := range options { + o.apply(s) + } + return s +} + +func (s *memoryNSServer) setEventChannelSize(l int) { + s.eventChannelSize = l +} + +func (s *memoryNSServer) Register(ctx context.Context, ns *registry.NetworkService) (*registry.NetworkService, error) { r, err := next.NetworkServiceRegistryServer(ctx).Register(ctx, ns) if err != nil { return nil, err } - n.networkServices.Store(r.Name, r.Clone()) - n.executor.AsyncExec(func() { - for _, ch := range n.eventChannels { - ch <- r.Clone() + + s.networkServices.Store(r.Name, r.Clone()) + + s.sendEvent(r) + + return r, nil +} + +func (s *memoryNSServer) sendEvent(event *registry.NetworkService) { + event = event.Clone() + s.executor.AsyncExec(func() { + for _, ch := range s.eventChannels { + ch <- event.Clone() } }) - return r, nil } -func (n *networkServiceRegistryServer) Find(query *registry.NetworkServiceQuery, s registry.NetworkServiceRegistry_FindServer) error { - sendAllMatches := func(ns *registry.NetworkService) error { - var err error - n.networkServices.Range(func(key string, value *registry.NetworkService) bool { - if matchutils.MatchNetworkServices(ns, value) { - err = s.Send(value.Clone()) - return err == nil +func (s *memoryNSServer) Find(query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer) error { + if !query.Watch { + for _, ns := range s.allMatches(query) { + if err := server.Send(ns); err != nil { + return err } - return true - }) + } + return next.NetworkServiceRegistryServer(server.Context()).Find(query, server) + } + + eventCh := make(chan *registry.NetworkService, s.eventChannelSize) + id := uuid.New().String() + + s.executor.AsyncExec(func() { + s.eventChannels[id] = eventCh + for _, entity := range s.allMatches(query) { + eventCh <- entity + } + }) + defer s.closeEventChannel(id, eventCh) + + var err error + for ; err == nil; err = s.receiveEvent(query, server, eventCh) { + } + if err != io.EOF { return err } - if query.Watch { - eventCh := make(chan *registry.NetworkService, n.eventChannelSize) - id := uuid.New().String() - n.executor.AsyncExec(func() { - n.eventChannels[id] = eventCh - }) - defer n.executor.AsyncExec(func() { - delete(n.eventChannels, id) - }) - err := sendAllMatches(query.NetworkService) - if err != nil { - return err + return next.NetworkServiceRegistryServer(server.Context()).Find(query, server) +} + +func (s *memoryNSServer) allMatches(query *registry.NetworkServiceQuery) (matches []*registry.NetworkService) { + s.networkServices.Range(func(_ string, ns *registry.NetworkService) bool { + if matchutils.MatchNetworkServices(query.NetworkService, ns) { + matches = append(matches, ns.Clone()) } - notifyChannel := func() error { - select { - case <-s.Context().Done(): - return io.EOF - case event := <-eventCh: - if matchutils.MatchNetworkServices(query.NetworkService, event) { - if s.Context().Err() != nil { - return io.EOF - } - if err := s.Send(event); err != nil { - return err - } - } - return nil - } + return true + }) + return matches +} + +func (s *memoryNSServer) closeEventChannel(id string, eventCh <-chan *registry.NetworkService) { + ctx, cancel := context.WithCancel(context.Background()) + + s.executor.AsyncExec(func() { + delete(s.eventChannels, id) + cancel() + }) + + for { + select { + case <-ctx.Done(): + return + case <-eventCh: } - for { - err := notifyChannel() - if errors.Is(err, io.EOF) { - break + } +} + +func (s *memoryNSServer) receiveEvent( + query *registry.NetworkServiceQuery, + server registry.NetworkServiceRegistry_FindServer, + eventCh <-chan *registry.NetworkService, +) error { + select { + case <-server.Context().Done(): + return io.EOF + case event := <-eventCh: + if matchutils.MatchNetworkServices(query.NetworkService, event) { + if server.Context().Err() != nil { + return io.EOF } - if err != nil { + if err := server.Send(event); err != nil { return err } } - } else if err := sendAllMatches(query.NetworkService); err != nil { - return err + return nil } - return next.NetworkServiceRegistryServer(s.Context()).Find(query, s) } -func (n *networkServiceRegistryServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*empty.Empty, error) { - n.networkServices.Delete(ns.Name) - return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns) -} +func (s *memoryNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*empty.Empty, error) { + s.networkServices.Delete(ns.Name) -func (n *networkServiceRegistryServer) setEventChannelSize(l int) { - n.eventChannelSize = l -} - -// NewNetworkServiceRegistryServer creates new memory based NetworkServiceRegistryServer -func NewNetworkServiceRegistryServer(options ...Option) registry.NetworkServiceRegistryServer { - r := &networkServiceRegistryServer{ - eventChannelSize: defaultEventChannelSize, - eventChannels: make(map[string]chan *registry.NetworkService), - } - for _, o := range options { - o.apply(r) - } - return r + return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns) } diff --git a/pkg/registry/common/memory/ns_server_test.go b/pkg/registry/common/memory/ns_server_test.go index 93e6fefe9..be40783c5 100644 --- a/pkg/registry/common/memory/ns_server_test.go +++ b/pkg/registry/common/memory/ns_server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,14 +18,20 @@ package memory_test import ( "context" + "fmt" + "sync" "testing" + "time" - "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/protobuf/proto" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/networkservicemesh/sdk/pkg/registry/common/memory" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" ) @@ -106,3 +112,129 @@ func TestNetworkServiceRegistryServer_RegisterAndFindWatch(t *testing.T) { require.NoError(t, err) require.True(t, proto.Equal(expected, <-ch)) } + +func TestNetworkServiceRegistryServer_DataRace(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := memory.NewNetworkServiceRegistryServer() + + _, err := s.Register(ctx, ®istry.NetworkService{Name: "ns"}) + require.NoError(t, err) + + c := adapters.NetworkServiceServerToClient(s) + + var wgStart, wgEnd sync.WaitGroup + for i := 0; i < 10; i++ { + wgStart.Add(1) + wgEnd.Add(1) + go func() { + defer wgEnd.Done() + + findCtx, findCancel := context.WithTimeout(ctx, time.Second) + defer findCancel() + + stream, err := c.Find(findCtx, ®istry.NetworkServiceQuery{ + NetworkService: ®istry.NetworkService{Name: "ns"}, + Watch: true, + }) + assert.NoError(t, err) + + _, err = stream.Recv() + assert.NoError(t, err) + + wgStart.Done() + + for j := 0; j < 100; j++ { + ns, err := stream.Recv() + assert.NoError(t, err) + + ns.Name = "" + } + }() + } + wgStart.Wait() + + for i := 0; i < 100; i++ { + _, err := s.Register(ctx, ®istry.NetworkService{Name: fmt.Sprintf("ns-%d", i)}) + require.NoError(t, err) + } + + wgEnd.Wait() +} + +func TestNetworkServiceRegistryServer_SlowReceiver(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := memory.NewNetworkServiceRegistryServer() + + c := adapters.NetworkServiceServerToClient(s) + + findCtx, findCancel := context.WithCancel(ctx) + + stream, err := c.Find(findCtx, ®istry.NetworkServiceQuery{ + NetworkService: ®istry.NetworkService{Name: "ns"}, + Watch: true, + }) + require.NoError(t, err) + + for i := 0; i < 1000; i++ { + _, err = s.Register(ctx, ®istry.NetworkService{Name: fmt.Sprintf("ns-%d", i)}) + require.NoError(t, err) + } + + ignoreCurrent := goleak.IgnoreCurrent() + + _, err = stream.Recv() + require.NoError(t, err) + + findCancel() + + require.Eventually(t, func() bool { + return goleak.Find(ignoreCurrent) == nil + }, 100*time.Millisecond, time.Millisecond) +} + +func TestNetworkServiceRegistryServer_ShouldReceiveAllRegisters(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := memory.NewNetworkServiceRegistryServer() + + c := adapters.NetworkServiceServerToClient(s) + + var wg sync.WaitGroup + for i := 0; i < 300; i++ { + wg.Add(1) + name := fmt.Sprintf("ns-%d", i) + + go func() { + _, err := s.Register(ctx, ®istry.NetworkService{Name: name}) + require.NoError(t, err) + }() + + go func() { + defer wg.Done() + + findCtx, findCancel := context.WithTimeout(ctx, time.Second) + defer findCancel() + + stream, err := c.Find(findCtx, ®istry.NetworkServiceQuery{ + NetworkService: ®istry.NetworkService{Name: name}, + Watch: true, + }) + assert.NoError(t, err) + + _, err = stream.Recv() + assert.NoError(t, err) + }() + } + wg.Wait() +} diff --git a/pkg/registry/common/memory/nse_server.go b/pkg/registry/common/memory/nse_server.go index bcc1dd622..a9a4ae79f 100644 --- a/pkg/registry/common/memory/nse_server.go +++ b/pkg/registry/common/memory/nse_server.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,127 +18,149 @@ package memory import ( "context" - "errors" "io" - "github.com/google/uuid" - + "github.com/edwarnicke/serialize" + "github.com/golang/protobuf/ptypes/empty" "github.com/golang/protobuf/ptypes/timestamp" + "github.com/google/uuid" - "github.com/golang/protobuf/ptypes/empty" "github.com/networkservicemesh/api/pkg/api/registry" - "github.com/edwarnicke/serialize" - "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/tools/matchutils" ) -type networkServiceEndpointRegistryServer struct { +type memoryNSEServer struct { networkServiceEndpoints NetworkServiceEndpointSyncMap executor serialize.Executor eventChannels map[string]chan *registry.NetworkServiceEndpoint eventChannelSize int } -func (n *networkServiceEndpointRegistryServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { +// NewNetworkServiceEndpointRegistryServer creates new memory based NetworkServiceEndpointRegistryServer +func NewNetworkServiceEndpointRegistryServer(options ...Option) registry.NetworkServiceEndpointRegistryServer { + s := &memoryNSEServer{ + eventChannelSize: defaultEventChannelSize, + eventChannels: make(map[string]chan *registry.NetworkServiceEndpoint), + } + for _, o := range options { + o.apply(s) + } + return s +} + +func (s *memoryNSEServer) setEventChannelSize(l int) { + s.eventChannelSize = l +} + +func (s *memoryNSEServer) Register(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*registry.NetworkServiceEndpoint, error) { r, err := next.NetworkServiceEndpointRegistryServer(ctx).Register(ctx, nse) if err != nil { return nil, err } - n.networkServiceEndpoints.Store(r.Name, r.Clone()) - n.sendEvent(r.Clone()) + + s.networkServiceEndpoints.Store(r.Name, r.Clone()) + + s.sendEvent(r) + return r, err } -func (n *networkServiceEndpointRegistryServer) Find(query *registry.NetworkServiceEndpointQuery, s registry.NetworkServiceEndpointRegistry_FindServer) error { - sendAllMatches := func(ns *registry.NetworkServiceEndpoint) error { - var err error - n.networkServiceEndpoints.Range(func(key string, value *registry.NetworkServiceEndpoint) bool { - if matchutils.MatchNetworkServiceEndpoints(ns, value) { - err = s.Send(value.Clone()) - return err == nil - } - return true - }) - return err - } - if query.Watch { - eventCh := make(chan *registry.NetworkServiceEndpoint, n.eventChannelSize) - id := uuid.New().String() - n.executor.AsyncExec(func() { - n.eventChannels[id] = eventCh - }) - defer n.executor.AsyncExec(func() { - delete(n.eventChannels, id) - }) - err := sendAllMatches(query.NetworkServiceEndpoint) - if err != nil { - return err - } - notifyChannel := func() error { - select { - case <-s.Context().Done(): - return io.EOF - case event := <-eventCh: - if matchutils.MatchNetworkServiceEndpoints(query.NetworkServiceEndpoint, event) { - if s.Context().Err() != nil { - return io.EOF - } - if err := s.Send(event); err != nil { - return err - } - } - return nil - } +func (s *memoryNSEServer) sendEvent(event *registry.NetworkServiceEndpoint) { + event = event.Clone() + s.executor.AsyncExec(func() { + for _, ch := range s.eventChannels { + ch <- event.Clone() } - for { - err := notifyChannel() - if errors.Is(err, io.EOF) { - break - } - if err != nil { + }) +} + +func (s *memoryNSEServer) Find(query *registry.NetworkServiceEndpointQuery, server registry.NetworkServiceEndpointRegistry_FindServer) error { + if !query.Watch { + for _, ns := range s.allMatches(query) { + if err := server.Send(ns); err != nil { return err } } - } else if err := sendAllMatches(query.NetworkServiceEndpoint); err != nil { + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) + } + + eventCh := make(chan *registry.NetworkServiceEndpoint, s.eventChannelSize) + id := uuid.New().String() + + s.executor.AsyncExec(func() { + s.eventChannels[id] = eventCh + for _, entity := range s.allMatches(query) { + eventCh <- entity + } + }) + defer s.closeEventChannel(id, eventCh) + + var err error + for ; err == nil; err = s.receiveEvent(query, server, eventCh) { + } + if err != io.EOF { return err } - return next.NetworkServiceEndpointRegistryServer(s.Context()).Find(query, s) + return next.NetworkServiceEndpointRegistryServer(server.Context()).Find(query, server) } -func (n *networkServiceEndpointRegistryServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { - n.networkServiceEndpoints.Delete(nse.Name) - if nse.ExpirationTime == nil { - nse.ExpirationTime = ×tamp.Timestamp{} - } - <-n.executor.AsyncExec(func() { - nse.ExpirationTime.Seconds = -1 +func (s *memoryNSEServer) allMatches(query *registry.NetworkServiceEndpointQuery) (matches []*registry.NetworkServiceEndpoint) { + s.networkServiceEndpoints.Range(func(_ string, nse *registry.NetworkServiceEndpoint) bool { + if matchutils.MatchNetworkServiceEndpoints(query.NetworkServiceEndpoint, nse) { + matches = append(matches, nse.Clone()) + } + return true }) - n.sendEvent(nse) - return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) + return matches } -func (n *networkServiceEndpointRegistryServer) setEventChannelSize(l int) { - n.eventChannelSize = l -} +func (s *memoryNSEServer) closeEventChannel(id string, eventCh <-chan *registry.NetworkServiceEndpoint) { + ctx, cancel := context.WithCancel(context.Background()) -func (n *networkServiceEndpointRegistryServer) sendEvent(nse *registry.NetworkServiceEndpoint) { - n.executor.AsyncExec(func() { - for _, ch := range n.eventChannels { - ch <- nse - } + s.executor.AsyncExec(func() { + delete(s.eventChannels, id) + cancel() }) + + for { + select { + case <-ctx.Done(): + return + case <-eventCh: + } + } } -// NewNetworkServiceEndpointRegistryServer creates new memory based NetworkServiceEndpointRegistryServer -func NewNetworkServiceEndpointRegistryServer(options ...Option) registry.NetworkServiceEndpointRegistryServer { - r := &networkServiceEndpointRegistryServer{ - eventChannelSize: defaultEventChannelSize, - eventChannels: make(map[string]chan *registry.NetworkServiceEndpoint), +func (s *memoryNSEServer) receiveEvent( + query *registry.NetworkServiceEndpointQuery, + server registry.NetworkServiceEndpointRegistry_FindServer, + eventCh <-chan *registry.NetworkServiceEndpoint, +) error { + select { + case <-server.Context().Done(): + return io.EOF + case event := <-eventCh: + if matchutils.MatchNetworkServiceEndpoints(query.NetworkServiceEndpoint, event) { + if server.Context().Err() != nil { + return io.EOF + } + if err := server.Send(event); err != nil { + return err + } + } + return nil } - for _, o := range options { - o.apply(r) +} + +func (s *memoryNSEServer) Unregister(ctx context.Context, nse *registry.NetworkServiceEndpoint) (*empty.Empty, error) { + s.networkServiceEndpoints.Delete(nse.Name) + + nse.ExpirationTime = ×tamp.Timestamp{ + Seconds: -1, } - return r + s.sendEvent(nse) + + return next.NetworkServiceEndpointRegistryServer(ctx).Unregister(ctx, nse) } diff --git a/pkg/registry/common/memory/nse_server_test.go b/pkg/registry/common/memory/nse_server_test.go index c7323e9eb..2ea66d8bf 100644 --- a/pkg/registry/common/memory/nse_server_test.go +++ b/pkg/registry/common/memory/nse_server_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2020 Doc.ai and/or its affiliates. +// Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -18,14 +18,20 @@ package memory_test import ( "context" + "fmt" + "io" + "sync" "testing" + "time" "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" "google.golang.org/protobuf/proto" "github.com/networkservicemesh/sdk/pkg/registry/common/memory" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" "github.com/networkservicemesh/sdk/pkg/registry/core/next" "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" ) @@ -177,6 +183,188 @@ func TestNetworkServiceEndpointRegistryServer_RegisterAndFindByLabelWatch(t *tes require.True(t, proto.Equal(expected, <-ch)) } +func TestNetworkServiceEndpointRegistryServer_DataRace(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := memory.NewNetworkServiceEndpointRegistryServer() + + _, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: "nse"}) + require.NoError(t, err) + + c := adapters.NetworkServiceEndpointServerToClient(s) + + var wgStart, wgEnd sync.WaitGroup + for i := 0; i < 10; i++ { + wgStart.Add(1) + wgEnd.Add(1) + go func() { + defer wgEnd.Done() + + findCtx, findCancel := context.WithTimeout(ctx, time.Second) + defer findCancel() + + stream, err := c.Find(findCtx, ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: "nse"}, + Watch: true, + }) + assert.NoError(t, err) + + _, err = stream.Recv() + assert.NoError(t, err) + + wgStart.Done() + + for j := 0; j < 100; j++ { + nse, err := stream.Recv() + assert.NoError(t, err) + + nse.Name = "" + } + }() + } + wgStart.Wait() + + for i := 0; i < 100; i++ { + _, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: fmt.Sprintf("nse-%d", i)}) + require.NoError(t, err) + } + + wgEnd.Wait() +} + +func TestNetworkServiceEndpointRegistryServer_SlowReceiver(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := memory.NewNetworkServiceEndpointRegistryServer() + + c := adapters.NetworkServiceEndpointServerToClient(s) + + findCtx, findCancel := context.WithCancel(ctx) + + stream, err := c.Find(findCtx, ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: "nse"}, + Watch: true, + }) + require.NoError(t, err) + + for i := 0; i < 1000; i++ { + _, err = s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: fmt.Sprintf("nse-%d", i)}) + require.NoError(t, err) + } + + ignoreCurrent := goleak.IgnoreCurrent() + + _, err = stream.Recv() + require.NoError(t, err) + + findCancel() + + require.Eventually(t, func() bool { + return goleak.Find(ignoreCurrent) == nil + }, 100*time.Millisecond, time.Millisecond) +} + +func TestNetworkServiceEndpointRegistryServer_ShouldReceiveAllRegisters(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := memory.NewNetworkServiceEndpointRegistryServer() + + c := adapters.NetworkServiceEndpointServerToClient(s) + + var wg sync.WaitGroup + for i := 0; i < 300; i++ { + wg.Add(1) + name := fmt.Sprintf("nse-%d", i) + + go func() { + _, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: name}) + require.NoError(t, err) + }() + + go func() { + defer wg.Done() + + findCtx, findCancel := context.WithTimeout(ctx, time.Second) + defer findCancel() + + stream, err := c.Find(findCtx, ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: name}, + Watch: true, + }) + assert.NoError(t, err) + + _, err = stream.Recv() + assert.NoError(t, err) + }() + } + wg.Wait() +} + +func TestNetworkServiceEndpointRegistryServer_ShouldReceiveAllUnregisters(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := memory.NewNetworkServiceEndpointRegistryServer() + + for i := 0; i < 300; i++ { + _, err := s.Register(ctx, ®istry.NetworkServiceEndpoint{Name: fmt.Sprintf("nse-%d", i)}) + require.NoError(t, err) + } + + c := adapters.NetworkServiceEndpointServerToClient(s) + + var wg sync.WaitGroup + for i := 0; i < 300; i++ { + wg.Add(1) + name := fmt.Sprintf("nse-%d", i) + + go func() { + _, err := s.Unregister(ctx, ®istry.NetworkServiceEndpoint{Name: name}) + assert.NoError(t, err) + }() + + go func() { + defer wg.Done() + + findCtx, findCancel := context.WithTimeout(ctx, time.Second) + defer findCancel() + + stream, err := c.Find(findCtx, ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{Name: name}, + Watch: true, + }) + assert.NoError(t, err) + + exists := false + for err == nil { + var nse *registry.NetworkServiceEndpoint + nse, err = stream.Recv() + switch { + case err != nil: + assert.Equal(t, io.EOF, err) + case nse.ExpirationTime != nil && nse.ExpirationTime.Seconds < 0: + return + default: + exists = true + } + } + assert.False(t, exists) + }() + } + wg.Wait() +} + func createLabeledNSE1() *registry.NetworkServiceEndpoint { labels := map[string]*registry.NetworkServiceLabels{ "Service1": {